Spaces:
Runtime error
Upload 34 files
- .gitattributes +0 -4
- App_main.py +109 -0
- qs_run.sh +32 -0
- readme 2.md +117 -0
- readme_cn.md +116 -0
- requirements.txt +26 -0
- tools/__init__.py +0 -0
- tools/ins_seg/analysis_tools/browse_dataset_mmdet_mmyolo_pl.py +266 -0
- tools/ins_seg/analysis_tools/dataset_analysis.py +518 -0
- tools/ins_seg/dataset_converters/cityscapes.py +152 -0
- tools/ins_seg/dataset_converters/whu_building_convert.py +143 -0
- tools/ins_seg/sam/sam_cls/get_sam_cls_crops.py +120 -0
- tools/ins_seg/sam/sam_cls/get_sam_cls_metrics.py +170 -0
- tools/ins_seg/sam/sam_cls/segment_anything/__init__.py +15 -0
- tools/ins_seg/sam/sam_cls/segment_anything/automatic_mask_generator.py +372 -0
- tools/ins_seg/sam/sam_cls/segment_anything/build_sam.py +107 -0
- tools/ins_seg/sam/sam_cls/segment_anything/modeling/__init__.py +11 -0
- tools/ins_seg/sam/sam_cls/segment_anything/modeling/common.py +43 -0
- tools/ins_seg/sam/sam_cls/segment_anything/modeling/image_encoder.py +395 -0
- tools/ins_seg/sam/sam_cls/segment_anything/modeling/mask_decoder.py +176 -0
- tools/ins_seg/sam/sam_cls/segment_anything/modeling/prompt_encoder.py +214 -0
- tools/ins_seg/sam/sam_cls/segment_anything/modeling/sam.py +174 -0
- tools/ins_seg/sam/sam_cls/segment_anything/modeling/transformer.py +240 -0
- tools/ins_seg/sam/sam_cls/segment_anything/predictor.py +269 -0
- tools/ins_seg/sam/sam_cls/segment_anything/utils/__init__.py +5 -0
- tools/ins_seg/sam/sam_cls/segment_anything/utils/amg.py +346 -0
- tools/ins_seg/sam/sam_cls/segment_anything/utils/onnx.py +144 -0
- tools/ins_seg/sam/sam_cls/segment_anything/utils/transforms.py +102 -0
- tools/predict.py +46 -0
- tools/test.py +43 -0
- tools/train.py +48 -0
- visualizer/test_img.jpg +0 -0
.gitattributes
CHANGED
@@ -33,7 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-data/WHU/annotations/WHU_building_test.json filter=lfs diff=lfs merge=lfs -text
-data/WHU/annotations/WHU_building_train.json filter=lfs diff=lfs merge=lfs -text
-mmpretrain/annotations/WHU_building_test.json filter=lfs diff=lfs merge=lfs -text
-mmpretrain/annotations/WHU_building_train.json filter=lfs diff=lfs merge=lfs -text
App_main.py
ADDED
@@ -0,0 +1,109 @@
import glob
import mmcv
import mmengine
import numpy as np
import os
from mmengine import Config, get
from mmengine.dataset import Compose
from mmpl.registry import MODELS, VISUALIZERS
from mmpl.utils import register_all_modules
register_all_modules()
# os.system('nvidia-smi')
# os.system('ls /usr/local')
# os.system('pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117')
# os.system('pip install -U openmim')
# os.system('mim install mmcv==2.0.0')
# os.system('mim install mmengine')

import gradio as gr
import torch

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


def construct_sample(img, pipeline):
    img = np.array(img)[:, :, ::-1]
    inputs = {
        'ori_shape': img.shape[:2],
        'img': img,
    }
    pipeline = Compose(pipeline)
    sample = pipeline(inputs)
    return sample

def build_model(cp, model_cfg):
    model_cpkt = torch.load(cp, map_location='cpu')
    model = MODELS.build(model_cfg)
    model.load_state_dict(model_cpkt, strict=True)
    model.to(device=device)
    model.eval()
    return model


# Function for building extraction
def inference_func(ori_img, cp):
    checkpoint = f'pretrain/{cp}_anchor.pth'
    cfg = f'configs/huggingface/rsprompter_anchor_{cp}_config.py'
    cfg = Config.fromfile(cfg)
    sample = construct_sample(ori_img, cfg.predict_pipeline)
    sample['inputs'] = [sample['inputs']]
    sample['data_samples'] = [sample['data_samples']]

    print('Use: ', device)
    model = build_model(checkpoint, cfg.model_cfg)

    with torch.no_grad():
        pred_results = model.predict_step(sample, batch_idx=0)

    cfg.visualizer.setdefault('save_dir', 'visualizer')
    visualizer = VISUALIZERS.build(cfg.visualizer)

    data_sample = pred_results[0]
    img = np.array(ori_img).copy()
    out_file = 'visualizer/test_img.jpg'
    mmengine.mkdir_or_exist(os.path.dirname(out_file))
    visualizer.add_datasample(
        'test_img',
        img,
        draw_gt=False,
        data_sample=data_sample,
        show=False,
        wait_time=0.01,
        pred_score_thr=0.4,
        out_file=out_file
    )
    img_bytes = get(out_file)
    img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
    return img

title = "RSPrompter"
description = "Gradio demo for RSPrompter. Upload image from WHU building dataset, NWPU dataset, or SSDD Dataset or click any one of the examples, " \
              "Then select the prompt model, and click \"Submit\" and wait for the result. \n \n" \
              "Paper: RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model"

article = "<p style='text-align: center'><a href='https://kyanchen.github.io/RSPrompter/' target='_blank'>RSPrompter Project " \
          "Page</a></p> "

files = glob.glob('examples/NWPU*')
examples = [[f, f.split('/')[-1].split('_')[0]] for f in files]

with gr.Blocks() as demo:
    image_input = gr.Image(type='pil', label='Input Img')
    # with gr.Row().style(equal_height=True):
    #     image_LR_output = gr.outputs.Image(label='LR Img', type='numpy')
    image_output = gr.Image(label='Segment Result', type='numpy')
    with gr.Row():
        checkpoint = gr.Radio(['WHU', 'NWPU', 'SSDD'], label='Checkpoint')

io = gr.Interface(fn=inference_func,
                  inputs=[image_input, checkpoint],
                  outputs=[image_output],
                  title=title,
                  description=description,
                  article=article,
                  allow_flagging='auto',
                  examples=examples,
                  cache_examples=True,
                  layout="grid"
                  )
io.launch()
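For a quick local sanity check of App_main.py's inference path outside the Gradio UI, something like the sketch below should work; the example file name is hypothetical (it only needs to match the `examples/NWPU*` glob above), and the `pretrain/` weights and `configs/huggingface/` configs shipped with this Space must be present.

```python
# Hedged sketch: call the demo's inference function directly, without the UI.
# 'examples/NWPU_0001.jpg' is a hypothetical file name; the second argument must
# be one of the Radio choices above ('WHU', 'NWPU', 'SSDD').
from PIL import Image

img = Image.open('examples/NWPU_0001.jpg')
result = inference_func(img, 'NWPU')   # RGB numpy array with the predicted masks drawn
Image.fromarray(result).save('smoke_test_result.jpg')
```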
qs_run.sh
ADDED
@@ -0,0 +1,32 @@
#!/bin/bash
source ~/.bashrc
conda activate torch2mmcv2  # torch1mmcv1 torch1mmcv2 torch2mmcv1 torch2mmcv2
pip install albumentations
pip install importlib_metadata
pip install --upgrade mmengine
pip install instaboostfast
#pip install deepspeed
# pip install anypackage
# yum install which
# source /opt/rh/devtoolset-9/enable
# mim install mmcv>=2.0.0rc4

cd /mnt/search01/usr/chenkeyan/codes/lightning_framework
#TORCH_DISTRIBUTED_DEBUG=DETAIL
case $# in
  0)
    python tools/train.py
    ;;
  1)
    python tools/train.py --config $1
    ;;
  2)
    python tools/train.py --config $1 --ckpt-path $2
    ;;
esac
# TORCH_DISTRIBUTED_DEBUG=DETAIL
#python train.py
#python -m torch.distributed.launch --nproc_per_node=$GPU_NUM --nnodes=$WORLD_SIZE --node_rank=$RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --use_env train.py
#python -m torch.distributed.launch --nproc_per_node=$GPU_NUM --nnodes=$WORLD_SIZE --node_rank=$RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --use_env train_pipe.py
# juicesync src dst
# juicefs rmr your_dir
readme 2.md
ADDED
@@ -0,0 +1,117 @@
# RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model

English | [简体中文](/readme_cn.md)

This is the PyTorch implementation of our paper "RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model".

[Project Page](https://kyanchen.github.io/RSPrompter/) $\cdot$ [PDF Download](https://arxiv.org/abs/2306.16269) $\cdot$ [HuggingFace Demo](https://huggingface.co/spaces/KyanChen/RSPrompter)

## 0. Environment Setup

### 0.1 Create a virtual environment

```shell
conda create -n RSPrompter python=3.10
```

### 0.2 Activate the virtual environment
```shell
conda activate RSPrompter
```

### 0.3 Install PyTorch
Version 1.x also works, but version 2.x is recommended.
```shell
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu117
```

### 0.3 [Optional] Install PyTorch with conda
```shell
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
```

### 0.4 Install MMCV
Version 2.x is recommended.
```shell
pip install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch2.0/index.html
```
Please refer to the [installation documentation](https://mmcv.readthedocs.io/en/latest/get_started/installation.html) for more details.

### 0.5 Install other dependencies
```shell
pip install -r requirements.txt
```

## 1. Data Preparation

### 1.1 Dataset

#### WHU Dataset
The WHU dataset can be downloaded from [WHU](https://aistudio.baidu.com/aistudio/datasetdetail/56502). After downloading, put it into the **data** folder, which already contains a few example images.

#### NWPU Dataset
The NWPU dataset can be downloaded from [NWPU](https://aistudio.baidu.com/aistudio/datasetdetail/52812). After downloading, put it into the **data** folder, which already contains a few example images.

#### SSDD Dataset
The SSDD dataset can be downloaded from [SSDD](https://aistudio.baidu.com/aistudio/datasetdetail/100924). After downloading, put it into the **data** folder, which already contains a few example images.

### 1.2 Split the dataset into train and test sets
The dataset split files and annotation files used in the paper are provided in this project; they are stored in the **data/*/annotations** folders in COCO annotation format.

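For reference, these annotation JSON files follow the standard COCO instance-segmentation layout. The skeleton below is purely illustrative (file names, ids, and values are made up); only the field structure is meaningful.

```python
# Illustrative skeleton of one data/*/annotations/*.json file (values are made up).
coco_style_annotations = {
    "images": [
        {"id": 0, "file_name": "example_0001.png", "height": 512, "width": 512},
    ],
    "annotations": [
        {
            "id": 0,
            "image_id": 0,
            "category_id": 1,
            "bbox": [100.0, 120.0, 40.0, 35.0],      # XYWH in pixels
            "area": 1400.0,
            "segmentation": {"size": [512, 512], "counts": "..."},  # RLE (or polygon list)
            "iscrowd": 0,
        },
    ],
    "categories": [
        {"id": 1, "name": "building"},
    ],
}
```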
## 2. Model Training

### 2.1 Train SAM-based models

#### 2.1.1 Config file
The config files are located in the **configs/rsprompter** folder and can be modified as needed. Configs are provided for three models: SAM-seg, SAM-det, and RSPrompter.

#### 2.1.2 Train
Some training parameters can also be modified in the configuration file above, mainly those in trainer_cfg (e.g., single-GPU vs. multi-GPU training). For the available options, please refer to the Trainer of PyTorch Lightning; a hedged sketch follows the command below.
```shell
python tools/train.py
```

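As a rough illustration only, a Lightning-style trainer section usually exposes options such as the ones below. The key names mirror `lightning.pytorch.Trainer` arguments; whether this repo's trainer_cfg uses exactly these names is an assumption, so check the shipped configs.

```python
# Hedged sketch of a PyTorch-Lightning-style trainer section (key names are
# standard Trainer arguments; their presence in this repo's trainer_cfg is assumed).
trainer_cfg = dict(
    accelerator='gpu',     # or 'cpu'
    devices=1,             # e.g. 4 for single-node multi-GPU training
    strategy='auto',       # e.g. 'ddp' for distributed data parallel
    max_epochs=100,
    precision='32-true',   # or '16-mixed' for mixed precision
    log_every_n_steps=20,
)
```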
### 2.2 [Optional] Train other models
#### 2.2.1 Config file
The **configs/rsprompter** folder only provides configurations for Mask R-CNN and Mask2Former. Configurations for other models can be derived from these two files and the model configs in MMDetection.

#### 2.2.2 Train
Modify the config path in **tools/train.py** and then run
```shell
python tools/train.py
```

## 3. Model Evaluation
The config files are located in the **configs/rsprompter** folder and can be modified as needed.
When val_evaluator and val_loader are configured, the model is automatically evaluated on the validation set during training; the results are uploaded to Wandb, where they can be viewed.
For offline evaluation on the test set, configure test_evaluator and test_loader in the configuration file, set the config and ckpt-path in **tools/test.py**, and then run
```shell
python tools/test.py
```

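For orientation, an MMDetection-style COCO evaluator/loader pair typically looks like the sketch below. How these entries are actually nested in this repo's configs (e.g., under datamodule_cfg) may differ, the dataset class name is hypothetical, and the paths are placeholders.

```python
# Hedged sketch of an offline test-set evaluator and loader in MMDetection 3.x style.
test_evaluator = dict(
    type='CocoMetric',
    ann_file='data/WHU/annotations/WHU_building_test.json',  # placeholder path
    metric=['bbox', 'segm'],
)
test_loader = dict(
    batch_size=2,
    num_workers=4,
    dataset=dict(
        type='WHUInsSegDataset',                 # hypothetical dataset class name
        ann_file='annotations/WHU_building_test.json',
        data_prefix=dict(img='test/image/'),     # placeholder prefix
    ),
)
```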
## 4. [Optional] Model Visualization
The config files are located in the **configs/rsprompter** folder and can be modified as needed. Adjust the parameters of DetVisualizationHook and DetLocalVisualizer in the configuration file, set the config and ckpt-path in **tools/predict.py**, and then run
```shell
python tools/predict.py
```

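As a non-authoritative pointer, the two knobs mentioned above usually appear in an MMDetection 3.x-style config roughly as follows; the exact values (and whether this repo keeps them under these keys) may differ.

```python
# Hedged sketch: visualization hook and visualizer entries in MMDetection 3.x style.
default_hooks = dict(
    visualization=dict(
        type='mmdet.DetVisualizationHook',
        draw=True,        # draw predictions onto the images
        show=False,       # set True to display windows instead of only saving
        interval=1,
    ),
)
visualizer = dict(
    type='mmdet.DetLocalVisualizer',
    vis_backends=[dict(type='mmdet.LocalVisBackend')],
    name='visualizer',
    save_dir='visualizer',   # where rendered images are written
)
```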
## 5. [Optional] Model Download
This project provides the model weights of RSPrompter-anchor, which are available in the [huggingface space](https://huggingface.co/spaces/KyanChen/RSPrompter/tree/main/pretrain).

## 6. [Optional] Citation
If you find this project useful for your research, please cite our paper.

If you have any other questions, please contact me.

```
@misc{chen2023rsprompter,
      title={RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model},
      author={Keyan Chen and Chenyang Liu and Hao Chen and Haotian Zhang and Wenyuan Li and Zhengxia Zou and Zhenwei Shi},
      year={2023},
      eprint={2306.16269},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
readme_cn.md
ADDED
@@ -0,0 +1,116 @@
# RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model

[English](/readme.md) | 简体中文

本项目是论文"RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model"的Pytorch实现

[项目主页](https://kyanchen.github.io/RSPrompter/) $\cdot$ [PDF下载](https://arxiv.org/abs/2306.16269) $\cdot$ [HuggingFace 样例](https://huggingface.co/spaces/KyanChen/RSPrompter)

## 0. 环境准备
### 0.1 建立虚拟环境
```shell
conda create -n RSPrompter python=3.10
```
### 0.2 激活虚拟环境
```shell
conda activate RSPrompter
```
### 0.3 安装pytorch
1.x版本也可以,但是建议使用2.x版本
```shell
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu117
```
### 0.3 [可选]安装pytorch
```shell
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
```
### 0.4 安装mmcv
建议2.x版本
```shell
pip install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch2.0/index.html
```
更多安装信息请参考[安装文档](https://mmcv.readthedocs.io/zh_CN/latest/get_started/installation.html)
### 0.5 安装其他依赖
```shell
pip install -r requirements.txt
```

## 1. 数据准备

### 1.1 数据集

#### WHU数据集
WHU数据集可以从[WHU](https://aistudio.baidu.com/aistudio/datasetdetail/56502)下载,下载后将数据集放到**data**文件夹中,该文件夹放入了一些图像示例。

#### NWPU数据集
NWPU数据集可以从[NWPU](https://aistudio.baidu.com/aistudio/datasetdetail/52812)下载,下载后将数据集放到**data**文件夹中,该文件夹放入了一些图像示例。

#### SSDD数据集
SSDD数据集可以从[SSDD](https://aistudio.baidu.com/aistudio/datasetdetail/100924)下载,下载后将数据集放到**data**文件夹中,该文件夹放入了一些图像示例。

### 1.2 划分训练测试集
在本项目中已提供论文中的数据集划分文件和标注文件,以COCO标注格式存储,位于**data/*/annotations**文件夹中。

## 2. 模型训练

### 2.1 训练SAM-based模型

#### 2.1.1 配置文件
配置文件位于**configs/rsprompter**文件夹中,可以依据情况修改该文件中的参数,提供了SAM-seg,SAM-det,RSPrompter三种模型的配置文件。

#### 2.1.2 训练
训练的一些参数配置也可以在上述配置文件中修改,主要修改trainer_cfg中的参数,例如单卡多卡训练等,具体配置修改参考Pytorch Lightning的Trainer。
```shell
python tools/train.py
```

### 2.2 [可选] 训练其他模型
#### 2.2.1 配置文件
配置文件位于**configs/rsprompter**文件夹中,仅提供了Mask R-CNN和Mask2Former的配置,其他模型的配置可以参考这两个配置文件和MMDetection中的模型config进行修改。

#### 2.2.2 训练
修改**tools/train.py**中的config路径,然后运行
```shell
python tools/train.py
```

## 3. 模型评测

模型配置文件位于**configs/rsprompter**文件夹中,可以依据情况修改该文件中的参数。
当配置了该文件中val_evaluator和val_loader,在模型训练时,会自动进行模型在验证集上的评测,评测结果会上传到Wandb中,可以在Wandb中查看。
如果需要在测试集上进行离线评测,需要配置配置文件中的test_evaluator和test_loader,以及**tools/test.py**中的config和ckpt-path路径,然后运行
```shell
python tools/test.py
```

## 4. [可选]结果可视化
模型配置文件位于**configs/rsprompter**文件夹中,可以依据情况修改该文件中的**DetVisualizationHook**和**DetLocalVisualizer**的参数,
以及**tools/predict.py**中的config和ckpt-path路径,然后运行
```shell
python tools/predict.py
```

## 5. [可选]模型下载
本项目提供了RSPrompter-anchor的模型权重,位于[huggingface space](https://huggingface.co/spaces/KyanChen/RSPrompter/tree/main/pretrain)中

## 6. [可选]引用
如果您认为本项目对您的研究有所帮助,请引用我们的论文。

如果您有其他问题,请联系我!

```
@misc{chen2023rsprompter,
      title={RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model},
      author={Keyan Chen and Chenyang Liu and Hao Chen and Haotian Zhang and Wenyuan Li and Zhengxia Zou and Zhenwei Shi},
      year={2023},
      eprint={2306.16269},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
requirements.txt
ADDED
@@ -0,0 +1,26 @@
torch>=2.0.0
mmcv>=2.0.0rc4
importlib_metadata
mat4py

torchvision
lightning==2.0.1
mmengine
openmim
wandb
redis
scikit-learn
scikit-image
tenacity
torchmetrics
tensorboardx
transformers
ipdb
prettytable
einops

imageio
pycocotools
shapely
terminaltables
albumentations
tools/__init__.py
ADDED
File without changes
tools/ins_seg/analysis_tools/browse_dataset_mmdet_mmyolo_pl.py
ADDED
@@ -0,0 +1,266 @@
import argparse
import os.path as osp
import sys
from typing import Tuple

import cv2
import mmcv
import numpy as np
from mmdet.models.utils import mask2ndarray
from mmdet.structures.bbox import BaseBoxes
from mmengine.config import Config, DictAction
from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar
from mmengine.visualization import Visualizer

from mmyolo.registry import DATASETS, VISUALIZERS


# TODO: Support for printing the change in key of results
def parse_args():
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('--config', default='configs/ins_seg/seg_maskrcnn_whu_bs8_config.py', help='train config file path')
    parser.add_argument(
        '--phase',
        '-p',
        default='val',
        type=str,
        choices=['train', 'test', 'val'],
        help='phase of dataset to visualize, accept "train" "test" and "val".'
        ' Defaults to "train".')
    parser.add_argument(
        '--mode',
        '-m',
        default='transformed',
        type=str,
        choices=['original', 'transformed', 'pipeline'],
        help='display mode; display original pictures or '
        'transformed pictures or comparison pictures. "original" '
        'means show images load from disk; "transformed" means '
        'to show images after transformed; "pipeline" means show all '
        'the intermediate images. Defaults to "transformed".')
    parser.add_argument(
        '--out-dir',
        default='output',
        type=str,
        help='If there is no display interface, you can save it.')
    parser.add_argument('--not-show', default=False, action='store_true')
    parser.add_argument(
        '--show-number',
        '-n',
        type=int,
        default=sys.maxsize,
        help='number of images selected to visualize, '
        'must bigger than 0. if the number is bigger than length '
        'of dataset, show all the images in dataset; '
        'default "sys.maxsize", show all images in dataset')
    parser.add_argument(
        '--show-interval',
        '-i',
        type=float,
        default=3,
        help='the interval of show (s)')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args


def _get_adaptive_scale(img_shape: Tuple[int, int],
                        min_scale: float = 0.3,
                        max_scale: float = 3.0) -> float:
    """Get adaptive scale according to image shape.

    The target scale depends on the the short edge length of the image. If the
    short edge length equals 224, the output is 1.0. And output linear
    scales according the short edge length. You can also specify the minimum
    scale and the maximum scale to limit the linear scale.

    Args:
        img_shape (Tuple[int, int]): The shape of the canvas image.
        min_scale (int): The minimum scale. Defaults to 0.3.
        max_scale (int): The maximum scale. Defaults to 3.0.
    Returns:
        int: The adaptive scale.
    """
    short_edge_length = min(img_shape)
    scale = short_edge_length / 224.
    return min(max(scale, min_scale), max_scale)


def make_grid(imgs, names):
    """Concat list of pictures into a single big picture, align height here."""
    visualizer = Visualizer.get_current_instance()
    ori_shapes = [img.shape[:2] for img in imgs]
    max_height = int(max(img.shape[0] for img in imgs) * 1.1)
    min_width = min(img.shape[1] for img in imgs)
    horizontal_gap = min_width // 10
    img_scale = _get_adaptive_scale((max_height, min_width))

    texts = []
    text_positions = []
    start_x = 0
    for i, img in enumerate(imgs):
        pad_height = (max_height - img.shape[0]) // 2
        pad_width = horizontal_gap // 2
        # make border
        imgs[i] = cv2.copyMakeBorder(
            img,
            pad_height,
            max_height - img.shape[0] - pad_height + int(img_scale * 30 * 2),
            pad_width,
            pad_width,
            cv2.BORDER_CONSTANT,
            value=(255, 255, 255))
        texts.append(f'{"execution: "}{i}\n{names[i]}\n{ori_shapes[i]}')
        text_positions.append(
            [start_x + img.shape[1] // 2 + pad_width, max_height])
        start_x += img.shape[1] + horizontal_gap

    display_img = np.concatenate(imgs, axis=1)
    visualizer.set_image(display_img)
    img_scale = _get_adaptive_scale(display_img.shape[:2])
    visualizer.draw_texts(
        texts,
        positions=np.array(text_positions),
        font_sizes=img_scale * 7,
        colors='black',
        horizontal_alignments='center',
        font_families='monospace')
    return visualizer.get_image()


class InspectCompose(Compose):
    """Compose multiple transforms sequentially.

    And record "img" field of all results in one list.
    """

    def __init__(self, transforms, intermediate_imgs):
        super().__init__(transforms=transforms)
        self.intermediate_imgs = intermediate_imgs

    def __call__(self, data):
        if 'img' in data:
            self.intermediate_imgs.append({
                'name': 'original',
                'img': data['img'].copy()
            })
        self.ptransforms = [
            self.transforms[i] for i in range(len(self.transforms) - 1)
        ]
        for t in self.ptransforms:
            data = t(data)
        # Keep the same meta_keys in the PackDetInputs
        self.transforms[-1].meta_keys = [key for key in data]
        data_sample = self.transforms[-1](data)
        if data is None:
            return None
        if 'img' in data:
            self.intermediate_imgs.append({
                'name':
                t.__class__.__name__,
                'dataset_sample':
                data_sample['data_samples']
            })
        return data


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    init_default_scope(cfg.get('default_scope', 'mmpl'))

    dataset_cfg = cfg.get('datamodule_cfg').get(args.phase + '_loader').get('dataset')
    dataset = DATASETS.build(dataset_cfg)

    # self added
    vis_backends = [dict(type='mmdet.LocalVisBackend')]
    visualizer = dict(
        type='mmdet.DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')

    visualizer = VISUALIZERS.build(visualizer)
    visualizer.dataset_meta = dataset.metainfo

    intermediate_imgs = []

    if not hasattr(dataset, 'pipeline'):
        # for dataset_wrapper
        dataset = dataset.dataset

    # TODO: The dataset wrapper occasion is not considered here
    dataset.pipeline = InspectCompose(dataset.pipeline.transforms,
                                      intermediate_imgs)

    # init visualization image number
    assert args.show_number > 0
    display_number = min(args.show_number, len(dataset))

    progress_bar = ProgressBar(display_number)
    for i, item in zip(range(display_number), dataset):
        image_i = []
        result_i = [result['dataset_sample'] for result in intermediate_imgs]
        for k, datasample in enumerate(result_i):
            image = datasample.img
            gt_instances = datasample.gt_instances
            image = image[..., [2, 1, 0]]  # bgr to rgb
            gt_bboxes = gt_instances.get('bboxes', None)
            if gt_bboxes is not None and isinstance(gt_bboxes, BaseBoxes):
                gt_instances.bboxes = gt_bboxes.tensor
            gt_masks = gt_instances.get('masks', None)
            if gt_masks is not None:
                masks = mask2ndarray(gt_masks)
                gt_instances.masks = masks.astype(bool)
            datasample.gt_instances = gt_instances
            # get filename from dataset or just use index as filename
            visualizer.add_datasample(
                'result',
                image,
                datasample,
                draw_pred=False,
                draw_gt=True,
                show=False)
            image_show = visualizer.get_image()
            image_i.append(image_show)

        if args.mode == 'original':
            image = image_i[0]
        elif args.mode == 'transformed':
            image = image_i[-1]
        else:
            image = make_grid([result for result in image_i],
                              [result['name'] for result in intermediate_imgs])

        if hasattr(datasample, 'img_path'):
            filename = osp.basename(datasample.img_path)
        else:
            # some dataset have not image path
            filename = f'{i}.jpg'
        out_file = osp.join(args.out_dir,
                            filename) if args.out_dir is not None else None

        if out_file is not None:
            mmcv.imwrite(image[..., ::-1], out_file)

        if not args.not_show:
            visualizer.show(
                image, win_name=filename, wait_time=args.show_interval)

        intermediate_imgs.clear()
        progress_bar.update()


if __name__ == '__main__':
    main()
tools/ins_seg/analysis_tools/dataset_analysis.py
ADDED
@@ -0,0 +1,518 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path
from statistics import median

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
from mmengine.config import Config
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar
from prettytable import PrettyTable

from mmyolo.registry import DATASETS
from mmyolo.utils.misc import show_data_classes


def parse_args():
    parser = argparse.ArgumentParser(
        description='Distribution of categories and bbox instances')
    parser.add_argument('--config', default='configs/ins_seg/seg_sam_queryprompt_ssdd_bs2_last_config.py', help='config file path')
    parser.add_argument(
        '--val-dataset',
        default=False,
        action='store_true',
        help='The default train_dataset.'
        'To change it to val_dataset, enter "--val-dataset"')
    parser.add_argument(
        '--class-name',
        default=None,
        type=str,
        help='Display specific class, e.g., "bicycle"')
    parser.add_argument(
        '--area-rule',
        default=None,
        type=int,
        nargs='+',
        help='Redefine area rules,but no more than three numbers.'
        ' e.g., 30 70 125')
    parser.add_argument(
        '--func',
        default='show_bbox_num',
        type=str,
        choices=[
            'show_bbox_num', 'show_bbox_wh', 'show_bbox_wh_ratio',
            'show_bbox_area'
        ],
        help='Dataset analysis function selection.')
    parser.add_argument(
        '--out-dir',
        default='./results/dataset_analysis',
        type=str,
        help='Output directory of dataset analysis visualization results,'
        ' Save in "./dataset_analysis/" by default')
    args = parser.parse_args()
    return args


def show_bbox_num(cfg, out_dir, fig_set, class_name, class_num):
    """Display the distribution map of categories and number of bbox
    instances."""
    print('\n\nDrawing bbox_num figure:')
    # Draw designs
    fig = plt.figure(
        figsize=(fig_set['figsize'][0], fig_set['figsize'][1]), dpi=300)
    plt.bar(class_name, class_num, align='center')

    # Draw titles, labels and so on
    for x, y in enumerate(class_num):
        plt.text(x, y, '%s' % y, ha='center', fontsize=fig_set['fontsize'] + 3)
    plt.xticks(rotation=fig_set['xticks_angle'])
    plt.xlabel('Category Name')
    plt.ylabel('Num of instances')
    plt.title(cfg.dataset_type)

    # Save figure
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    out_name = fig_set['out_name']
    fig.savefig(
        f'{out_dir}/{out_name}_bbox_num.jpg',
        bbox_inches='tight',
        pad_inches=0.1)  # Save Image
    plt.close()
    print(f'End and save in {out_dir}/{out_name}_bbox_num.jpg')


def show_bbox_wh(out_dir, fig_set, class_bbox_w, class_bbox_h, class_name):
    """Display the width and height distribution of categories and bbox
    instances."""
    print('\n\nDrawing bbox_wh figure:')
    # Draw designs
    fig, ax = plt.subplots(
        figsize=(fig_set['figsize'][0], fig_set['figsize'][1]), dpi=300)

    # Set the position of the map and label on the x-axis
    positions_w = list(range(0, 12 * len(class_name), 12))
    positions_h = list(range(6, 12 * len(class_name), 12))
    positions_x_label = list(range(3, 12 * len(class_name) + 1, 12))
    ax.violinplot(
        class_bbox_w, positions_w, showmeans=True, showmedians=True, widths=4)
    ax.violinplot(
        class_bbox_h, positions_h, showmeans=True, showmedians=True, widths=4)

    # Draw titles, labels and so on
    plt.xticks(rotation=fig_set['xticks_angle'])
    plt.ylabel('The width or height of bbox')
    plt.xlabel('Class name')
    plt.title('Width or height distribution of classes and bbox instances')

    # Draw the max, min and median of wide data in violin chart
    for i in range(len(class_bbox_w)):
        plt.text(
            positions_w[i],
            median(class_bbox_w[i]),
            f'{"%.2f" % median(class_bbox_w[i])}',
            ha='center',
            fontsize=fig_set['fontsize'])
        plt.text(
            positions_w[i],
            max(class_bbox_w[i]),
            f'{"%.2f" % max(class_bbox_w[i])}',
            ha='center',
            fontsize=fig_set['fontsize'])
        plt.text(
            positions_w[i],
            min(class_bbox_w[i]),
            f'{"%.2f" % min(class_bbox_w[i])}',
            ha='center',
            fontsize=fig_set['fontsize'])

    # Draw the max, min and median of height data in violin chart
    for i in range(len(positions_h)):
        plt.text(
            positions_h[i],
            median(class_bbox_h[i]),
            f'{"%.2f" % median(class_bbox_h[i])}',
            ha='center',
            fontsize=fig_set['fontsize'])
        plt.text(
            positions_h[i],
            max(class_bbox_h[i]),
            f'{"%.2f" % max(class_bbox_h[i])}',
            ha='center',
            fontsize=fig_set['fontsize'])
        plt.text(
            positions_h[i],
            min(class_bbox_h[i]),
            f'{"%.2f" % min(class_bbox_h[i])}',
            ha='center',
            fontsize=fig_set['fontsize'])

    # Draw Legend
    plt.setp(ax, xticks=positions_x_label, xticklabels=class_name)
    labels = ['bbox_w', 'bbox_h']
    colors = ['steelblue', 'darkorange']
    patches = [
        mpatches.Patch(color=colors[i], label=f'{labels[i]:s}')
        for i in range(len(colors))
    ]
    ax = plt.gca()
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width, box.height * 0.8])
    ax.legend(loc='upper center', handles=patches, ncol=2)

    # Save figure
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    out_name = fig_set['out_name']
    fig.savefig(
        f'{out_dir}/{out_name}_bbox_wh.jpg',
        bbox_inches='tight',
        pad_inches=0.1)  # Save Image
    plt.close()
    print(f'End and save in {out_dir}/{out_name}_bbox_wh.jpg')


def show_bbox_wh_ratio(out_dir, fig_set, class_name, class_bbox_ratio):
    """Display the distribution map of category and bbox instance width and
    height ratio."""
    print('\n\nDrawing bbox_wh_ratio figure:')
    # Draw designs
    fig, ax = plt.subplots(
        figsize=(fig_set['figsize'][0], fig_set['figsize'][1]), dpi=300)

    # Set the position of the map and label on the x-axis
    positions = list(range(0, 6 * len(class_name), 6))
    ax.violinplot(
        class_bbox_ratio,
        positions,
        showmeans=True,
        showmedians=True,
        widths=5)

    # Draw titles, labels and so on
    plt.xticks(rotation=fig_set['xticks_angle'])
    plt.ylabel('Ratio of width to height of bbox')
    plt.xlabel('Class name')
    plt.title('Width to height ratio distribution of class and bbox instances')

    # Draw the max, min and median of wide data in violin chart
    for i in range(len(class_bbox_ratio)):
        plt.text(
            positions[i],
            median(class_bbox_ratio[i]),
            f'{"%.2f" % median(class_bbox_ratio[i])}',
            ha='center',
            fontsize=fig_set['fontsize'])
        plt.text(
            positions[i],
            max(class_bbox_ratio[i]),
            f'{"%.2f" % max(class_bbox_ratio[i])}',
            ha='center',
            fontsize=fig_set['fontsize'])
        plt.text(
            positions[i],
            min(class_bbox_ratio[i]),
            f'{"%.2f" % min(class_bbox_ratio[i])}',
            ha='center',
            fontsize=fig_set['fontsize'])

    # Set the position of the map and label on the x-axis
    plt.setp(ax, xticks=positions, xticklabels=class_name)

    # Save figure
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    out_name = fig_set['out_name']
    fig.savefig(
        f'{out_dir}/{out_name}_bbox_ratio.jpg',
        bbox_inches='tight',
        pad_inches=0.1)  # Save Image
    plt.close()
    print(f'End and save in {out_dir}/{out_name}_bbox_ratio.jpg')


def show_bbox_area(out_dir, fig_set, area_rule, class_name, bbox_area_num):
    """Display the distribution map of category and bbox instance area based on
    the rules of large, medium and small objects."""
    print('\n\nDrawing bbox_area figure:')
    # Set the direct distance of each label and the width of each histogram
    # Set the required labels and colors
    positions = np.arange(0, 2 * len(class_name), 2)
    width = 0.4
    labels = ['Small', 'Mediun', 'Large', 'Huge']
    colors = ['#438675', '#F7B469', '#6BA6DA', '#913221']

    # Draw designs
    fig = plt.figure(
        figsize=(fig_set['figsize'][0], fig_set['figsize'][1]), dpi=300)
    for i in range(len(area_rule) - 1):
        area_num = [bbox_area_num[idx][i] for idx in range(len(class_name))]
        plt.bar(
            positions + width * i,
            area_num,
            width,
            label=labels[i],
            color=colors[i])
        for idx, (x, y) in enumerate(zip(positions.tolist(), area_num)):
            plt.text(
                x + width * i,
                y,
                y,
                ha='center',
                fontsize=fig_set['fontsize'] - 1)

    # Draw titles, labels and so on
    plt.xticks(rotation=fig_set['xticks_angle'])
    plt.xticks(positions + width * ((len(area_rule) - 2) / 2), class_name)
    plt.ylabel('Class Area')
    plt.xlabel('Class Name')
    plt.title(
        'Area and number of large, medium and small objects of each class')

    # Set and Draw Legend
    patches = [
        mpatches.Patch(color=colors[i], label=f'{labels[i]:s}')
        for i in range(len(area_rule) - 1)
    ]
    ax = plt.gca()
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width, box.height * 0.8])
    ax.legend(loc='upper center', handles=patches, ncol=len(area_rule) - 1)

    # Save figure
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    out_name = fig_set['out_name']
    fig.savefig(
        f'{out_dir}/{out_name}_bbox_area.jpg',
        bbox_inches='tight',
        pad_inches=0.1)  # Save Image
    plt.close()
    print(f'End and save in {out_dir}/{out_name}_bbox_area.jpg')


def show_class_list(classes, class_num):
    """Print the data of the class obtained by the current run."""
    print('\n\nThe information obtained is as follows:')
    class_info = PrettyTable()
    class_info.title = 'Information of dataset class'
    # List Print Settings
    # If the quantity is too large, 25 rows will be displayed in each column
    if len(classes) < 25:
        class_info.add_column('Class name', classes)
        class_info.add_column('Bbox num', class_num)
    elif len(classes) % 25 != 0 and len(classes) > 25:
        col_num = int(len(classes) / 25) + 1
        class_nums = class_num.tolist()
        class_name_list = list(classes)
        for i in range(0, (col_num * 25) - len(classes)):
            class_name_list.append('')
            class_nums.append('')
        for i in range(0, len(class_name_list), 25):
            class_info.add_column('Class name', class_name_list[i:i + 25])
            class_info.add_column('Bbox num', class_nums[i:i + 25])

    # Align display data to the left
    class_info.align['Class name'] = 'l'
    class_info.align['Bbox num'] = 'l'
    print(class_info)


def show_data_list(args, area_rule):
    """Print run setup information."""
    print('\n\nPrint current running information:')
    data_info = PrettyTable()
    data_info.title = 'Dataset information'
    # Print the corresponding information according to the settings
    if args.val_dataset is False:
        data_info.add_column('Dataset type', ['train_dataset'])
    elif args.val_dataset is True:
        data_info.add_column('Dataset type', ['val_dataset'])
    if args.class_name is None:
        data_info.add_column('Class name', ['All classes'])
    else:
        data_info.add_column('Class name', [args.class_name])
    if args.func is None:
        data_info.add_column('Function', ['All function'])
    else:
        data_info.add_column('Function', [args.func])
    data_info.add_column('Area rule', [area_rule])

    print(data_info)


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)

    init_default_scope(cfg.get('default_scope', 'mmpl'))

    def replace_pipeline_to_none(cfg):
        """Recursively iterate over all dataset(or datasets) and set their
        pipelines to none.Datasets are mean ConcatDataset.

        Recursively terminates only when all dataset(or datasets) have been
        traversed
        """

        if cfg.get('dataset', None) is None and cfg.get('datasets',
                                                        None) is None:
            return
        dataset = cfg.dataset if cfg.get('dataset', None) else cfg.datasets
        if isinstance(dataset, list):
            for item in dataset:
                item.pipeline = None
        elif dataset.get('pipeline', None):
            dataset.pipeline = None
        else:
            replace_pipeline_to_none(dataset)

    # 1.Build Dataset
    dataset_cfg = cfg.get('datamodule_cfg')
    if args.val_dataset is False:
        replace_pipeline_to_none(dataset_cfg.train_loader)
        dataset = DATASETS.build(dataset_cfg.train_loader.dataset)
    else:
        replace_pipeline_to_none(dataset_cfg.val_loader)
        dataset = DATASETS.build(dataset_cfg.val_loader.dataset)

    # 2.Prepare data
    # Drawing settings
    fig_all_set = {
        'figsize': [35, 18],
        'fontsize': int(10 - 0.08 * len(dataset.metainfo['classes'])),
        'xticks_angle': 70,
        'out_name': cfg.dataset_type
    }
    fig_one_set = {
        'figsize': [15, 10],
        'fontsize': 10,
        'xticks_angle': 0,
        'out_name': args.class_name
    }

    # Call the category name and save address
    if args.class_name is None:
        classes = dataset.metainfo['classes']
        classes_idx = [i for i in range(len(classes))]
        fig_set = fig_all_set
    elif args.class_name in dataset.metainfo['classes']:
        classes = [args.class_name]
        classes_idx = [dataset.metainfo['classes'].index(args.class_name)]
        fig_set = fig_one_set
    else:
        data_classes = dataset.metainfo['classes']
        show_data_classes(data_classes)
        raise RuntimeError(f'Expected args.class_name to be one of the list,'
                           f'but got "{args.class_name}"')

    # Building Area Rules
    if args.area_rule is None:
        area_rule = [0, 32, 96, 1e5]
    elif args.area_rule and len(args.area_rule) <= 3:
        area_rules = [0] + args.area_rule + [1e5]
        area_rule = sorted(area_rules)
    else:
        raise RuntimeError(
            f'Expected the "{args.area_rule}" to be e.g. 30 60 120, '
            'and no more than three numbers.')

    # Build arrays or lists to store data for each category
    class_num = np.zeros((len(classes), ), dtype=np.int64)
    class_bbox = [[] for _ in classes]
    class_name = []
    class_bbox_w = []
    class_bbox_h = []
    class_bbox_ratio = []
    bbox_area_num = []
    instance_num = []

    show_data_list(args, area_rule)
    # Get the quantity and bbox data corresponding to each category
    print('\nRead the information of each picture in the dataset:')
    progress_bar = ProgressBar(len(dataset))

    counts_instances = 0
    for index in range(len(dataset)):
        instances = dataset[index]['instances']
        # if len(instances) > 100:
        #     counts_instances += 1
        #     # continue
        #     labels = [instance['bbox_label'] for instance in instances]
        #     counts = np.bincount(labels)
        #     label_id = np.argmax(counts)
        #     # Harbor Large_Vehicle Small_Vehicle ship
        #     print(f'the class is {dataset.metainfo["classes"][label_id]}')
        #     print('The number of bboxes in the picture is greater than 100')
        instance_num.append(len(instances))
        for instance in dataset[index]['instances']:
            if instance[
                    'bbox_label'] in classes_idx and args.class_name is None:
                class_num[instance['bbox_label']] += 1
                class_bbox[instance['bbox_label']].append(instance['bbox'])
            elif instance['bbox_label'] in classes_idx and args.class_name:
                class_num[0] += 1
                class_bbox[0].append(instance['bbox'])
        progress_bar.update()
    show_class_list(classes, class_num)
    print(f'The number of bboxes in the picture is greater than 120: {counts_instances}')
    # Get the width, height and area of bbox corresponding to each category
    print('\nRead bbox information in each class:')
    progress_bar_classes = ProgressBar(len(classes))
    for idx, (classes, classes_idx) in enumerate(zip(classes, classes_idx)):
        bbox = np.array(class_bbox[idx])
        bbox_area_nums = np.zeros((len(area_rule) - 1, ), dtype=np.int64)
        if len(bbox) > 0:
            bbox_wh = bbox[:, 2:4] - bbox[:, 0:2]
            bbox_ratio = bbox_wh[:, 0] / bbox_wh[:, 1]
            bbox_area = bbox_wh[:, 0] * bbox_wh[:, 1]
            class_bbox_w.append(bbox_wh[:, 0].tolist())
            class_bbox_h.append(bbox_wh[:, 1].tolist())
            class_bbox_ratio.append(bbox_ratio.tolist())

            # The area rule, there is an section between two numbers
            for i in range(len(area_rule) - 1):
                bbox_area_nums[i] = np.logical_and(
                    bbox_area >= area_rule[i]**2,
                    bbox_area < area_rule[i + 1]**2).sum()
        elif len(bbox) == 0:
            class_bbox_w.append([0])
            class_bbox_h.append([0])
            class_bbox_ratio.append([0])

        class_name.append(classes)
        bbox_area_num.append(bbox_area_nums.tolist())
        progress_bar_classes.update()

    # 3.draw Dataset Information
    if args.func is None:
        show_bbox_num(cfg, args.out_dir, fig_set, class_name, class_num)
        show_bbox_wh(args.out_dir, fig_set, class_bbox_w, class_bbox_h,
                     class_name)
        show_bbox_wh_ratio(args.out_dir, fig_set, class_name, class_bbox_ratio)
        show_bbox_area(args.out_dir, fig_set, area_rule, class_name,
                       bbox_area_num)
    elif args.func == 'show_bbox_num':
        show_bbox_num(cfg, args.out_dir, fig_set, class_name, class_num)
        print('num_instances_info:')
        print('max num_instances=', max(instance_num))
        print('min num_instances=', min(instance_num))
        print('mean num_instances=', np.mean(instance_num))
    elif args.func == 'show_bbox_wh':
        show_bbox_wh(args.out_dir, fig_set, class_bbox_w, class_bbox_h,
                     class_name)
    elif args.func == 'show_bbox_wh_ratio':
        show_bbox_wh_ratio(args.out_dir, fig_set, class_name, class_bbox_ratio)
    elif args.func == 'show_bbox_area':
        show_bbox_area(args.out_dir, fig_set, area_rule, class_name,
                       bbox_area_num)
    else:
        raise RuntimeError(
            'Please enter the correct func name, e.g., show_bbox_num')


if __name__ == '__main__':
    main()
tools/ins_seg/dataset_converters/cityscapes.py
ADDED
@@ -0,0 +1,152 @@
import argparse
import glob
import os.path as osp

import cityscapesscripts.helpers.labels as CSLabels
import mmcv
import numpy as np
import pycocotools.mask as maskUtils
from mmengine.fileio import dump
from mmengine.utils import (Timer, mkdir_or_exist, track_parallel_progress,
                            track_progress)


def collect_files(img_dir, gt_dir):
    suffix = 'leftImg8bit.png'
    files = []
    for img_file in glob.glob(osp.join(img_dir, '**/*.png')):
        assert img_file.endswith(suffix), img_file
        inst_file = gt_dir + img_file[
            len(img_dir):-len(suffix)] + 'gtFine_instanceIds.png'
        # Note that labelIds are not converted to trainId for seg map
        segm_file = gt_dir + img_file[
            len(img_dir):-len(suffix)] + 'gtFine_labelIds.png'
        files.append((img_file, inst_file, segm_file))
    assert len(files), f'No images found in {img_dir}'
    print(f'Loaded {len(files)} images from {img_dir}')

    return files


def collect_annotations(files, nproc=1):
    print('Loading annotation images')
    if nproc > 1:
        images = track_parallel_progress(load_img_info, files, nproc=nproc)
    else:
        images = track_progress(load_img_info, files)

    return images


def load_img_info(files):
    img_file, inst_file, segm_file = files
    inst_img = mmcv.imread(inst_file, 'unchanged')
    # ids < 24 are stuff labels (filtering them first is about 5% faster)
    unique_inst_ids = np.unique(inst_img[inst_img >= 24])
    anno_info = []
    for inst_id in unique_inst_ids:
        # For non-crowd annotations, inst_id // 1000 is the label_id
        # Crowd annotations have <1000 instance ids
        label_id = inst_id // 1000 if inst_id >= 1000 else inst_id
        label = CSLabels.id2label[label_id]
        if not label.hasInstances or label.ignoreInEval:
            continue

        category_id = label.id
        iscrowd = int(inst_id < 1000)
        mask = np.asarray(inst_img == inst_id, dtype=np.uint8, order='F')
        mask_rle = maskUtils.encode(mask[:, :, None])[0]

        area = maskUtils.area(mask_rle)
        # convert to COCO style XYWH format
        bbox = maskUtils.toBbox(mask_rle)

        # for json encoding
        mask_rle['counts'] = mask_rle['counts'].decode()

        anno = dict(
            iscrowd=iscrowd,
            category_id=category_id,
            bbox=bbox.tolist(),
            area=area.tolist(),
            segmentation=mask_rle)
        anno_info.append(anno)
    video_name = osp.basename(osp.dirname(img_file))
    img_info = dict(
        # remove img_prefix for filename
        file_name=osp.join(video_name, osp.basename(img_file)),
        height=inst_img.shape[0],
        width=inst_img.shape[1],
        anno_info=anno_info,
        segm_file=osp.join(video_name, osp.basename(segm_file)))

    return img_info


def cvt_annotations(image_infos, out_json_name):
    out_json = dict()
    img_id = 0
    ann_id = 0
    out_json['images'] = []
    out_json['categories'] = []
    out_json['annotations'] = []
    for image_info in image_infos:
        image_info['id'] = img_id
        anno_infos = image_info.pop('anno_info')
        out_json['images'].append(image_info)
        for anno_info in anno_infos:
            anno_info['image_id'] = img_id
            anno_info['id'] = ann_id
            out_json['annotations'].append(anno_info)
            ann_id += 1
        img_id += 1
    for label in CSLabels.labels:
        if label.hasInstances and not label.ignoreInEval:
            cat = dict(id=label.id, name=label.name)
            out_json['categories'].append(cat)

    if len(out_json['annotations']) == 0:
        out_json.pop('annotations')

    dump(out_json, out_json_name)
    return out_json


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert Cityscapes annotations to COCO format')
    parser.add_argument('cityscapes_path', help='cityscapes data path')
    parser.add_argument('--img-dir', default='leftImg8bit', type=str)
    parser.add_argument('--gt-dir', default='gtFine', type=str)
    parser.add_argument('-o', '--out-dir', help='output path')
    parser.add_argument(
        '--nproc', default=1, type=int, help='number of process')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    cityscapes_path = args.cityscapes_path
|
131 |
+
out_dir = args.out_dir if args.out_dir else cityscapes_path
|
132 |
+
mkdir_or_exist(out_dir)
|
133 |
+
|
134 |
+
img_dir = osp.join(cityscapes_path, args.img_dir)
|
135 |
+
gt_dir = osp.join(cityscapes_path, args.gt_dir)
|
136 |
+
|
137 |
+
set_name = dict(
|
138 |
+
train='instancesonly_filtered_gtFine_train.json',
|
139 |
+
val='instancesonly_filtered_gtFine_val.json',
|
140 |
+
test='instancesonly_filtered_gtFine_test.json')
|
141 |
+
|
142 |
+
for split, json_name in set_name.items():
|
143 |
+
print(f'Converting {split} into {json_name}')
|
144 |
+
with Timer(print_tmpl='It took {}s to convert Cityscapes annotation'):
|
145 |
+
files = collect_files(
|
146 |
+
osp.join(img_dir, split), osp.join(gt_dir, split))
|
147 |
+
image_infos = collect_annotations(files, nproc=args.nproc)
|
148 |
+
cvt_annotations(image_infos, osp.join(out_dir, json_name))
|
149 |
+
|
150 |
+
|
151 |
+
if __name__ == '__main__':
|
152 |
+
main()
|
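A minimal usage sketch for the converter above, assuming the standard Cityscapes layout (leftImg8bit/ and gtFine/ with train/val/test subfolders) under a hypothetical data/cityscapes root and that the tools package is importable from the repo root; otherwise the script can simply be run directly via its command line entry point.

# Hypothetical paths; adjust to the actual dataset location.
from tools.ins_seg.dataset_converters.cityscapes import (
    collect_files, collect_annotations, cvt_annotations)

files = collect_files('data/cityscapes/leftImg8bit/val',
                      'data/cityscapes/gtFine/val')
image_infos = collect_annotations(files, nproc=4)
cvt_annotations(
    image_infos,
    'data/cityscapes/annotations/instancesonly_filtered_gtFine_val.json')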
tools/ins_seg/dataset_converters/whu_building_convert.py
ADDED
@@ -0,0 +1,143 @@
import argparse
import glob
import os
import os.path as osp

import cv2
import mmcv
import numpy as np
import pycocotools.mask as maskUtils
from mmengine.fileio import dump
from mmengine.utils import (Timer, mkdir_or_exist, track_parallel_progress,
                            track_progress)


def collect_files(img_dir, gt_dir):
    files = []
    img_files = glob.glob(osp.join(img_dir, 'image/*.tif'))
    for img_file in img_files:
        segm_file = gt_dir + '/label/' + os.path.basename(img_file)
        files.append((img_file, segm_file))
    assert len(files), f'No images found in {img_dir}'
    print(f'Loaded {len(files)} images from {img_dir}')

    return files


def collect_annotations(files, nproc=1):
    print('Loading annotation images')
    if nproc > 1:
        images = track_parallel_progress(load_img_info, files, nproc=nproc)
    else:
        images = track_progress(load_img_info, files)

    return images


def load_img_info(files):
    img_file, segm_file = files
    segm_img = mmcv.imread(segm_file, flag='unchanged', backend='cv2')

    num_labels, instances, stats, centroids = cv2.connectedComponentsWithStats(segm_img, connectivity=4)

    anno_info = []
    for inst_id in range(1, num_labels):
        category_id = 1
        mask = np.asarray(instances == inst_id, dtype=np.uint8, order='F')
        if mask.max() < 1:
            print(f'Ignore empty instance: {inst_id} in {segm_file}')
            continue
        mask_rle = maskUtils.encode(mask[:, :, None])[0]
        area = maskUtils.area(mask_rle)
        # convert to COCO style XYWH format
        bbox = maskUtils.toBbox(mask_rle)

        # for json encoding
        mask_rle['counts'] = mask_rle['counts'].decode()

        anno = dict(
            iscrowd=0,
            category_id=category_id,
            bbox=bbox.tolist(),
            area=area.tolist(),
            segmentation=mask_rle)
        anno_info.append(anno)
    video_name = osp.basename(osp.dirname(img_file))
    img_info = dict(
        # remove img_prefix for filename
        file_name=osp.basename(img_file),
        height=segm_img.shape[0],
        width=segm_img.shape[1],
        anno_info=anno_info,
        segm_file=osp.basename(segm_file))

    return img_info


def cvt_annotations(image_infos, out_json_name):
    out_json = dict()
    img_id = 0
    ann_id = 0
    out_json['images'] = []
    out_json['categories'] = []
    out_json['annotations'] = []
    for image_info in image_infos:
        image_info['id'] = img_id
        anno_infos = image_info.pop('anno_info')
        out_json['images'].append(image_info)
        for anno_info in anno_infos:
            anno_info['image_id'] = img_id
            anno_info['id'] = ann_id
            out_json['annotations'].append(anno_info)
            ann_id += 1
        img_id += 1

    cat = dict(id=1, name='building')
    out_json['categories'].append(cat)

    if len(out_json['annotations']) == 0:
        out_json.pop('annotations')

    dump(out_json, out_json_name)
    return out_json


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert WHU Building annotations to COCO format')
    parser.add_argument('--cityscapes_path', default='/Users/kyanchen/datasets/Building/WHU', help='WHU Building data path')
    parser.add_argument('--img-dir', default='', type=str)
    parser.add_argument('--gt-dir', default='', type=str)
    parser.add_argument('-o', '--out-dir', default='/Users/kyanchen/datasets/Building/WHU/annotations', help='output path')
    parser.add_argument(
        '--nproc', default=0, type=int, help='number of process')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    cityscapes_path = args.cityscapes_path
    out_dir = args.out_dir if args.out_dir else cityscapes_path
    mkdir_or_exist(out_dir)

    img_dir = osp.join(cityscapes_path, args.img_dir)
    gt_dir = osp.join(cityscapes_path, args.gt_dir)

    set_name = dict(
        train='WHU_building_train.json',
        val='WHU_building_val.json',
        test='WHU_building_test.json'
    )

    for split, json_name in set_name.items():
        print(f'Converting {split} into {json_name}')
        with Timer(print_tmpl='It took {}s to convert WHU Building annotation'):
            files = collect_files(
                osp.join(img_dir, split), osp.join(gt_dir, split))
            image_infos = collect_annotations(files, nproc=args.nproc)
            cvt_annotations(image_infos, osp.join(out_dir, json_name))


if __name__ == '__main__':
    main()
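A small, self-contained sketch of the core idea in load_img_info above: a binary building mask is split into instances with connected components, and each instance becomes a COCO-style RLE with its area and XYWH box. The toy label map here is made up purely for illustration.

import cv2
import numpy as np
import pycocotools.mask as maskUtils

# Toy binary label map with two separate "buildings".
segm_img = np.zeros((8, 8), dtype=np.uint8)
segm_img[1:3, 1:4] = 1
segm_img[5:7, 5:8] = 1

num_labels, instances, stats, centroids = cv2.connectedComponentsWithStats(
    segm_img, connectivity=4)

for inst_id in range(1, num_labels):  # label 0 is the background component
    mask = np.asarray(instances == inst_id, dtype=np.uint8, order='F')
    rle = maskUtils.encode(mask[:, :, None])[0]
    # Prints the instance id, its pixel area, and its [x, y, w, h] box.
    print(inst_id, maskUtils.area(rle), maskUtils.toBbox(rle))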
tools/ins_seg/sam/sam_cls/get_sam_cls_crops.py
ADDED
@@ -0,0 +1,120 @@
import os
import cv2
import sys
sys.path.append(sys.path[0] + '/../../../..')
import torch
from mmengine import Config, ProgressBar
from torchvision.transforms import InterpolationMode
from mmpl.registry import DATASETS
from tools.ins_seg.sam.sam_cls.segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from mmdet.evaluation.functional import bbox_overlaps


device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
BICUBIC = InterpolationMode.BICUBIC

data_set_name = 'whu'

config_file = f'configs/ins_seg/samcls_{data_set_name}_config.py'
# sam_checkpoint = "pretrain/sam/sam_vit_b_01ec64.pth"
# model_type = "vit_b"
sam_checkpoint = "pretrain/sam/sam_vit_h_4b8939.pth"
model_type = "vit_h"
phase = 'val'

cache_data_root = f'/data/kyanchen/cache_data/ins_seg/sam_cls/{data_set_name}'
cache_data_root = os.path.join(cache_data_root, phase)
if not os.path.exists(cache_data_root):
    os.makedirs(cache_data_root)

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(
    sam,
    pred_iou_thresh=0.5,
    box_nms_thresh=0.5,
    stability_score_thresh=0.6,
    min_mask_region_area=16,
    crop_nms_thresh=0.6,
)

cfg = Config.fromfile(config_file)
dataset_cfg = cfg.get('datamodule_cfg')
dataset_cfg = dataset_cfg.get(f'{phase}_loader').dataset
dataset = DATASETS.build(dataset_cfg)
class_names = dataset.METAINFO['classes']

progress_bar = ProgressBar(len(dataset))
expand_ratio = 2
iou_thresh = 0.2
# 1741 2700 3700
for index in list(range(len(dataset)))[:500]:
    print(index)
    x = dataset[index]
    img_file = x['data_samples'].img_path
    gt_bbox = x['data_samples'].gt_instances.bboxes
    labels = x['data_samples'].gt_instances.labels
    gt_bbox = gt_bbox.tensor
    # image = cv2.imread(img_file)
    # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = x['inputs'].permute(1, 2, 0).numpy()[..., ::-1]
    masks = mask_generator.generate(image)
    pred_boxes = [mask['bbox'] for mask in masks]
    pred_boxes = torch.tensor(pred_boxes, dtype=torch.float32, device=device)
    pred_boxes[:, 2:] += pred_boxes[:, :2]
    # # debug to show the image
    # img = image.copy().astype('uint8')
    # for gt_box in gt_bbox:
    #     gt_box = gt_box.cpu().numpy().astype(int)
    #     cv2.rectangle(img, (gt_box[0], gt_box[1]), (gt_box[2], gt_box[3]), (0, 255, 0), 2)
    # for gt_box in pred_boxes:
    #     gt_box = gt_box.cpu().numpy().astype(int)
    #     cv2.rectangle(img, (gt_box[0], gt_box[1]), (gt_box[2], gt_box[3]), (0, 0, 255), 2)
    # cv2.imshow("image", img.astype('uint8'))
    # cv2.waitKey(0)

    ious = bbox_overlaps(gt_bbox.cpu().numpy(), pred_boxes.cpu().numpy())
    idxs = ious.argmax(axis=0)
    ious = ious[idxs, range(ious.shape[1])]
    ious_mask = ious > iou_thresh

    for idx, mask in enumerate(masks):
        # expand box
        x, y, w, h = mask['bbox']
        x = x + w // 2
        y = y + h // 2
        w = int(w * expand_ratio)
        h = int(h * expand_ratio)
        l = int(x - w // 2)
        t = int(y - h // 2)
        r = int(x + w // 2)
        b = int(y + h // 2)
        l = max(0, l)
        t = max(0, t)
        r = min(image.shape[1], r)
        b = min(image.shape[0], b)
        if r - l < 16 or b - t < 16:
            continue

        # blur image
        blur_image = image.copy()
        blur_image = cv2.blur(blur_image, (7, 7))
        blur_image[mask['segmentation']] = image[mask['segmentation']]
        crop_image = blur_image[t:b, l:r]

        # # # debug to show the image
        # cv2.imshow("crop_image", crop_image)
        # seg_mask = image.copy()
        # seg_mask[mask['segmentation']] = (0, 0, 255)
        # seg_mask[mask['segmentation']] = 0.5 * seg_mask[mask['segmentation']] + 0.5 * image[mask['segmentation']]
        # cv2.imshow("image", seg_mask.astype('uint8'))
        # cv2.waitKey(0)

        label = 255
        if ious_mask[idx]:
            label = labels[idxs[idx]].item()
        cv2.imwrite(os.path.join(cache_data_root, f"{index}_{idx}_crop_{label}.jpg"), crop_image)

    progress_bar.update()
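The labelling rule in the loop above can be hard to read inline, so here is a sketch that isolates it with made-up boxes: each SAM proposal inherits the label of the best-overlapping ground-truth box if the IoU clears iou_thresh, and is otherwise written out with the background label 255.

import numpy as np
from mmdet.evaluation.functional import bbox_overlaps

gt_bbox = np.array([[0, 0, 10, 10], [20, 20, 40, 40]], dtype=np.float32)      # xyxy GT boxes
pred_boxes = np.array([[1, 1, 9, 9], [100, 100, 120, 120]], dtype=np.float32)  # SAM proposals
labels = np.array([0, 0])     # per-GT class ids (single class for WHU buildings)
iou_thresh = 0.2

ious = bbox_overlaps(gt_bbox, pred_boxes)      # (num_gt, num_pred) IoU matrix
idxs = ious.argmax(axis=0)                     # best-matching GT per proposal
best = ious[idxs, range(ious.shape[1])]
crop_labels = np.where(best > iou_thresh, labels[idxs], 255)
print(crop_labels)                             # e.g. [0, 255]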
tools/ins_seg/sam/sam_cls/get_sam_cls_metrics.py
ADDED
@@ -0,0 +1,170 @@
import sys

import cv2
sys.path.append(sys.path[0] + '/../../../..')
from mmengine import Config, ProgressBar
from mmengine.dataset import Compose
from mmengine.structures import InstanceData
from torchvision.transforms import InterpolationMode

from mmpl.registry import DATASETS, MODELS, METRICS
from tools.ins_seg.sam.sam_cls.segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import torch
import mmpl.evaluation
import torch.nn.functional as F

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

dataset_name = 'whu'
seg_model_cfg_file = f'configs/ins_seg/samcls_{dataset_name}_config.py'
# sam_checkpoint = "pretrain/sam/sam_vit_b_01ec64.pth"
# model_type = "vit_b"
sam_checkpoint = "pretrain/sam/sam_vit_h_4b8939.pth"
model_type = "vit_h"


sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
sam.eval()
mask_generator = SamAutomaticMaskGenerator(
    sam,
    pred_iou_thresh=0.5,
    box_nms_thresh=0.5,
    stability_score_thresh=0.6,
    min_mask_region_area=16,
    crop_nms_thresh=0.6,
)

# cls model
cls_model_cfg_file = f'configs/ins_seg/samcls_res18_{dataset_name}_config.py'
cls_ckpt = 'results/whu_ins/E20230607_0/checkpoints/epoch_epoch=59-map_valmulticlassaccuracy_0=0.9516.ckpt'
cls_cfg = Config.fromfile(cls_model_cfg_file)
cls_model_cfg = cls_cfg.get('model_cfg').whole_model
cls_model = MODELS.build(cls_model_cfg)
cls_state_dict = torch.load(cls_ckpt, map_location='cpu')['state_dict']
cls_state_dict = {k.replace('whole_model.', ''): v for k, v in cls_state_dict.items()}
cls_model.load_state_dict(cls_state_dict, strict=True)
cls_model.to(device=device)
cls_model.eval()

cls_transform_cfg = cls_cfg.get('datamodule_cfg').val_loader.dataset.pipeline[1:]
cls_transforms = Compose(cls_transform_cfg)

seg_cfg = Config.fromfile(seg_model_cfg_file)
seg_dataset_cfg = seg_cfg.get('datamodule_cfg').val_loader.dataset
seg_dataset = DATASETS.build(seg_dataset_cfg)
val_evaluator_cfg = seg_cfg['evaluator'].val_evaluator
val_evaluator = METRICS.build(val_evaluator_cfg)

val_evaluator.dataset_meta = seg_dataset.metainfo


progress_bar = ProgressBar(len(seg_dataset))
expand_ratio = 2
for index in range(len(seg_dataset)):
    seg_data = seg_dataset[index]
    img_file = seg_data['data_samples'].img_path
    image = cv2.imread(img_file)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    masks = mask_generator.generate(image)

    refined_masks = []
    for mask in masks:
        bbbox = mask['bbox']
        x, y, w, h = bbbox
        if w < 4 or h < 4:
            continue
        refined_masks.append(mask)
    masks = refined_masks
    mask_cls = []
    for mask in masks:
        # expand box
        x, y, w, h = mask['bbox']
        x = x + w // 2
        y = y + h // 2
        w = int(w * expand_ratio)
        h = int(h * expand_ratio)
        l = int(x - w // 2)
        t = int(y - h // 2)
        r = int(x + w // 2)
        b = int(y + h // 2)
        l = max(0, l)
        t = max(0, t)
        r = min(image.shape[1], r)
        b = min(image.shape[0], b)

        # blur image
        blur_image = image.copy()
        blur_image = cv2.blur(blur_image, (7, 7))
        blur_image[mask['segmentation']] = image[mask['segmentation']]
        crop_image = blur_image[t:b, l:r]

        # # debug to show the image
        # cv2.imshow("crop_image", crop_image)
        # seg_mask = image.copy()
        # seg_mask[mask['segmentation']] = (0, 0, 255)
        # seg_mask[mask['segmentation']] = 0.5 * seg_mask[mask['segmentation']] + 0.5 * image[mask['segmentation']]
        # cv2.imshow("image", seg_mask.astype('uint8'))
        # cv2.waitKey(0)
        results = {
            'img': crop_image,
        }
        transform_crop_data = cls_transforms(results)
        transform_crop_data['inputs'] = transform_crop_data['inputs'].unsqueeze(0).to(device)
        transform_crop_data['data_samples'] = [transform_crop_data['data_samples']]
        data = cls_model.data_preprocessor(transform_crop_data, False)
        results = cls_model._run_forward(data, mode='predict')

        mask_cls.append(results[0].pred_score)

    mask_pred = torch.stack([torch.from_numpy(mask['segmentation']) for mask in masks], dim=0).to(device=device)
    bbox_pred = [mask['bbox'] for mask in masks]
    bbox_pred = torch.tensor(bbox_pred, dtype=torch.float32, device=device)
    bbox_pred[:, 2:] += bbox_pred[:, :2]

    mask_cls = torch.stack(mask_cls, dim=0)
    max_per_image = 100
    num_queries = mask_cls.shape[0]
    num_classes = mask_cls.shape[-1] - 1
    scores = F.softmax(mask_cls, dim=-1)[:, :-1]
    labels = torch.arange(num_classes, device=mask_cls.device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
    try:
        scores_per_image, top_indices = scores.flatten(0, 1).topk(max_per_image, sorted=False)
    except Exception as e:
        print(e)
        continue

    labels_per_image = labels[top_indices]
    query_indices = top_indices // num_classes
    mask_pred = mask_pred[query_indices]
    bbox_pred = bbox_pred[query_indices]

    # # extract things
    # is_thing = labels_per_image < self.num_things_classes
    # scores_per_image = scores_per_image[is_thing]
    # labels_per_image = labels_per_image[is_thing]
    # mask_pred = mask_pred[is_thing]

    mask_pred_binary = (mask_pred > 0).float()
    # mask_scores_per_image = (mask_pred.sigmoid() *
    #                          mask_pred_binary).flatten(1).sum(1) / (
    #                              mask_pred_binary.flatten(1).sum(1) + 1e-6)
    # det_scores = scores_per_image * mask_scores_per_image
    det_scores = scores_per_image
    mask_pred_binary = mask_pred_binary.bool()

    results = InstanceData()
    results.bboxes = bbox_pred
    results.labels = labels_per_image
    results.scores = det_scores
    results.masks = mask_pred_binary

    data_samples = seg_data['data_samples']
    data_samples.pred_instances = results

    val_evaluator.update(None, [data_samples])
    progress_bar.update()

metrics = val_evaluator.compute()
print(metrics)
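The score/label selection near the end of the loop above follows the familiar flatten-and-topk pattern; a minimal sketch with random scores so the index arithmetic is visible in isolation (all sizes here are made up).

import torch
import torch.nn.functional as F

num_queries, num_classes_plus_bg = 6, 2          # e.g. building + a background column
mask_cls = torch.randn(num_queries, num_classes_plus_bg)

num_classes = num_classes_plus_bg - 1
scores = F.softmax(mask_cls, dim=-1)[:, :-1]      # drop the background column
labels = torch.arange(num_classes).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)

k = min(4, scores.numel())
scores_per_image, top_indices = scores.flatten(0, 1).topk(k, sorted=False)
labels_per_image = labels[top_indices]            # class id of each kept prediction
query_indices = top_indices // num_classes        # which proposal it came from
print(scores_per_image.shape, labels_per_image, query_indices)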
tools/ins_seg/sam/sam_cls/segment_anything/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .build_sam import (
    build_sam,
    build_sam_vit_h,
    build_sam_vit_l,
    build_sam_vit_b,
    sam_model_registry,
)
from .predictor import SamPredictor
from .automatic_mask_generator import SamAutomaticMaskGenerator
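For reference, this vendored package is used exactly as in the two scripts above; a minimal sketch, assuming the ViT-H checkpoint path used elsewhere in this repo and the bundled test image.

import cv2
import torch
from tools.ins_seg.sam.sam_cls.segment_anything import (
    sam_model_registry, SamAutomaticMaskGenerator)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
sam = sam_model_registry['vit_h'](checkpoint='pretrain/sam/sam_vit_h_4b8939.pth')
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam, pred_iou_thresh=0.5)
image = cv2.cvtColor(cv2.imread('visualizer/test_img.jpg'), cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)   # list of dicts: 'segmentation', 'bbox', 'area', ...
print(len(masks), masks[0].keys() if masks else None)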
tools/ins_seg/sam/sam_cls/segment_anything/automatic_mask_generator.py
ADDED
@@ -0,0 +1,372 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
10 |
+
|
11 |
+
from typing import Any, Dict, List, Optional, Tuple
|
12 |
+
|
13 |
+
from .modeling import Sam
|
14 |
+
from .predictor import SamPredictor
|
15 |
+
from .utils.amg import (
|
16 |
+
MaskData,
|
17 |
+
area_from_rle,
|
18 |
+
batch_iterator,
|
19 |
+
batched_mask_to_box,
|
20 |
+
box_xyxy_to_xywh,
|
21 |
+
build_all_layer_point_grids,
|
22 |
+
calculate_stability_score,
|
23 |
+
coco_encode_rle,
|
24 |
+
generate_crop_boxes,
|
25 |
+
is_box_near_crop_edge,
|
26 |
+
mask_to_rle_pytorch,
|
27 |
+
remove_small_regions,
|
28 |
+
rle_to_mask,
|
29 |
+
uncrop_boxes_xyxy,
|
30 |
+
uncrop_masks,
|
31 |
+
uncrop_points,
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
class SamAutomaticMaskGenerator:
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
model: Sam,
|
39 |
+
points_per_side: Optional[int] = 32,
|
40 |
+
points_per_batch: int = 64,
|
41 |
+
pred_iou_thresh: float = 0.88,
|
42 |
+
stability_score_thresh: float = 0.95,
|
43 |
+
stability_score_offset: float = 1.0,
|
44 |
+
box_nms_thresh: float = 0.7,
|
45 |
+
crop_n_layers: int = 0,
|
46 |
+
crop_nms_thresh: float = 0.7,
|
47 |
+
crop_overlap_ratio: float = 512 / 1500,
|
48 |
+
crop_n_points_downscale_factor: int = 1,
|
49 |
+
point_grids: Optional[List[np.ndarray]] = None,
|
50 |
+
min_mask_region_area: int = 0,
|
51 |
+
output_mode: str = "binary_mask",
|
52 |
+
) -> None:
|
53 |
+
"""
|
54 |
+
Using a SAM model, generates masks for the entire image.
|
55 |
+
Generates a grid of point prompts over the image, then filters
|
56 |
+
low quality and duplicate masks. The default settings are chosen
|
57 |
+
for SAM with a ViT-H backbone.
|
58 |
+
|
59 |
+
Arguments:
|
60 |
+
model (Sam): The SAM model to use for mask prediction.
|
61 |
+
points_per_side (int or None): The number of points to be sampled
|
62 |
+
along one side of the image. The total number of points is
|
63 |
+
points_per_side**2. If None, 'point_grids' must provide explicit
|
64 |
+
point sampling.
|
65 |
+
points_per_batch (int): Sets the number of points run simultaneously
|
66 |
+
by the model. Higher numbers may be faster but use more GPU memory.
|
67 |
+
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
68 |
+
model's predicted mask quality.
|
69 |
+
stability_score_thresh (float): A filtering threshold in [0,1], using
|
70 |
+
the stability of the mask under changes to the cutoff used to binarize
|
71 |
+
the model's mask predictions.
|
72 |
+
stability_score_offset (float): The amount to shift the cutoff when
|
73 |
+
calculated the stability score.
|
74 |
+
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
75 |
+
suppression to filter duplicate masks.
|
76 |
+
crop_n_layers (int): If >0, mask prediction will be run again on
|
77 |
+
crops of the image. Sets the number of layers to run, where each
|
78 |
+
layer has 2**i_layer number of image crops.
|
79 |
+
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
80 |
+
suppression to filter duplicate masks between different crops.
|
81 |
+
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
82 |
+
In the first crop layer, crops will overlap by this fraction of
|
83 |
+
the image length. Later layers with more crops scale down this overlap.
|
84 |
+
crop_n_points_downscale_factor (int): The number of points-per-side
|
85 |
+
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
86 |
+
point_grids (list(np.ndarray) or None): A list over explicit grids
|
87 |
+
of points used for sampling, normalized to [0,1]. The nth grid in the
|
88 |
+
list is used in the nth crop layer. Exclusive with points_per_side.
|
89 |
+
min_mask_region_area (int): If >0, postprocessing will be applied
|
90 |
+
to remove disconnected regions and holes in masks with area smaller
|
91 |
+
than min_mask_region_area. Requires opencv.
|
92 |
+
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
93 |
+
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
94 |
+
For large resolutions, 'binary_mask' may consume large amounts of
|
95 |
+
memory.
|
96 |
+
"""
|
97 |
+
|
98 |
+
assert (points_per_side is None) != (
|
99 |
+
point_grids is None
|
100 |
+
), "Exactly one of points_per_side or point_grid must be provided."
|
101 |
+
if points_per_side is not None:
|
102 |
+
self.point_grids = build_all_layer_point_grids(
|
103 |
+
points_per_side,
|
104 |
+
crop_n_layers,
|
105 |
+
crop_n_points_downscale_factor,
|
106 |
+
)
|
107 |
+
elif point_grids is not None:
|
108 |
+
self.point_grids = point_grids
|
109 |
+
else:
|
110 |
+
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
111 |
+
|
112 |
+
assert output_mode in [
|
113 |
+
"binary_mask",
|
114 |
+
"uncompressed_rle",
|
115 |
+
"coco_rle",
|
116 |
+
], f"Unknown output_mode {output_mode}."
|
117 |
+
if output_mode == "coco_rle":
|
118 |
+
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
119 |
+
|
120 |
+
if min_mask_region_area > 0:
|
121 |
+
import cv2 # type: ignore # noqa: F401
|
122 |
+
|
123 |
+
self.predictor = SamPredictor(model)
|
124 |
+
self.points_per_batch = points_per_batch
|
125 |
+
self.pred_iou_thresh = pred_iou_thresh
|
126 |
+
self.stability_score_thresh = stability_score_thresh
|
127 |
+
self.stability_score_offset = stability_score_offset
|
128 |
+
self.box_nms_thresh = box_nms_thresh
|
129 |
+
self.crop_n_layers = crop_n_layers
|
130 |
+
self.crop_nms_thresh = crop_nms_thresh
|
131 |
+
self.crop_overlap_ratio = crop_overlap_ratio
|
132 |
+
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
133 |
+
self.min_mask_region_area = min_mask_region_area
|
134 |
+
self.output_mode = output_mode
|
135 |
+
|
136 |
+
@torch.no_grad()
|
137 |
+
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
138 |
+
"""
|
139 |
+
Generates masks for the given image.
|
140 |
+
|
141 |
+
Arguments:
|
142 |
+
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
list(dict(str, any)): A list over records for masks. Each record is
|
146 |
+
a dict containing the following keys:
|
147 |
+
segmentation (dict(str, any) or np.ndarray): The mask. If
|
148 |
+
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
149 |
+
is a dictionary containing the RLE.
|
150 |
+
bbox (list(float)): The box around the mask, in XYWH format.
|
151 |
+
area (int): The area in pixels of the mask.
|
152 |
+
predicted_iou (float): The model's own prediction of the mask's
|
153 |
+
quality. This is filtered by the pred_iou_thresh parameter.
|
154 |
+
point_coords (list(list(float))): The point coordinates input
|
155 |
+
to the model to generate this mask.
|
156 |
+
stability_score (float): A measure of the mask's quality. This
|
157 |
+
is filtered on using the stability_score_thresh parameter.
|
158 |
+
crop_box (list(float)): The crop of the image used to generate
|
159 |
+
the mask, given in XYWH format.
|
160 |
+
"""
|
161 |
+
|
162 |
+
# Generate masks
|
163 |
+
mask_data = self._generate_masks(image)
|
164 |
+
|
165 |
+
# Filter small disconnected regions and holes in masks
|
166 |
+
if self.min_mask_region_area > 0:
|
167 |
+
mask_data = self.postprocess_small_regions(
|
168 |
+
mask_data,
|
169 |
+
self.min_mask_region_area,
|
170 |
+
max(self.box_nms_thresh, self.crop_nms_thresh),
|
171 |
+
)
|
172 |
+
|
173 |
+
# Encode masks
|
174 |
+
if self.output_mode == "coco_rle":
|
175 |
+
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
|
176 |
+
elif self.output_mode == "binary_mask":
|
177 |
+
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
178 |
+
else:
|
179 |
+
mask_data["segmentations"] = mask_data["rles"]
|
180 |
+
|
181 |
+
# Write mask records
|
182 |
+
curr_anns = []
|
183 |
+
for idx in range(len(mask_data["segmentations"])):
|
184 |
+
ann = {
|
185 |
+
"segmentation": mask_data["segmentations"][idx],
|
186 |
+
"area": area_from_rle(mask_data["rles"][idx]),
|
187 |
+
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
188 |
+
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
189 |
+
"point_coords": [mask_data["points"][idx].tolist()],
|
190 |
+
"stability_score": mask_data["stability_score"][idx].item(),
|
191 |
+
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
192 |
+
}
|
193 |
+
curr_anns.append(ann)
|
194 |
+
|
195 |
+
return curr_anns
|
196 |
+
|
197 |
+
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
198 |
+
orig_size = image.shape[:2]
|
199 |
+
crop_boxes, layer_idxs = generate_crop_boxes(
|
200 |
+
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
201 |
+
)
|
202 |
+
|
203 |
+
# Iterate over image crops
|
204 |
+
data = MaskData()
|
205 |
+
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
206 |
+
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
207 |
+
data.cat(crop_data)
|
208 |
+
|
209 |
+
# Remove duplicate masks between crops
|
210 |
+
if len(crop_boxes) > 1:
|
211 |
+
# Prefer masks from smaller crops
|
212 |
+
scores = 1 / box_area(data["crop_boxes"])
|
213 |
+
scores = scores.to(data["boxes"].device)
|
214 |
+
keep_by_nms = batched_nms(
|
215 |
+
data["boxes"].float(),
|
216 |
+
scores,
|
217 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
218 |
+
iou_threshold=self.crop_nms_thresh,
|
219 |
+
)
|
220 |
+
data.filter(keep_by_nms)
|
221 |
+
|
222 |
+
data.to_numpy()
|
223 |
+
return data
|
224 |
+
|
225 |
+
def _process_crop(
|
226 |
+
self,
|
227 |
+
image: np.ndarray,
|
228 |
+
crop_box: List[int],
|
229 |
+
crop_layer_idx: int,
|
230 |
+
orig_size: Tuple[int, ...],
|
231 |
+
) -> MaskData:
|
232 |
+
# Crop the image and calculate embeddings
|
233 |
+
x0, y0, x1, y1 = crop_box
|
234 |
+
cropped_im = image[y0:y1, x0:x1, :]
|
235 |
+
cropped_im_size = cropped_im.shape[:2]
|
236 |
+
self.predictor.set_image(cropped_im)
|
237 |
+
|
238 |
+
# Get points for this crop
|
239 |
+
points_scale = np.array(cropped_im_size)[None, ::-1]
|
240 |
+
points_for_image = self.point_grids[crop_layer_idx] * points_scale
|
241 |
+
|
242 |
+
# Generate masks for this crop in batches
|
243 |
+
data = MaskData()
|
244 |
+
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
245 |
+
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
|
246 |
+
data.cat(batch_data)
|
247 |
+
del batch_data
|
248 |
+
self.predictor.reset_image()
|
249 |
+
|
250 |
+
# Remove duplicates within this crop.
|
251 |
+
keep_by_nms = batched_nms(
|
252 |
+
data["boxes"].float(),
|
253 |
+
data["iou_preds"],
|
254 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
255 |
+
iou_threshold=self.box_nms_thresh,
|
256 |
+
)
|
257 |
+
data.filter(keep_by_nms)
|
258 |
+
|
259 |
+
# Return to the original image frame
|
260 |
+
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
261 |
+
data["points"] = uncrop_points(data["points"], crop_box)
|
262 |
+
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
|
263 |
+
|
264 |
+
return data
|
265 |
+
|
266 |
+
def _process_batch(
|
267 |
+
self,
|
268 |
+
points: np.ndarray,
|
269 |
+
im_size: Tuple[int, ...],
|
270 |
+
crop_box: List[int],
|
271 |
+
orig_size: Tuple[int, ...],
|
272 |
+
) -> MaskData:
|
273 |
+
orig_h, orig_w = orig_size
|
274 |
+
|
275 |
+
# Run model on this batch
|
276 |
+
transformed_points = self.predictor.transform.apply_coords(points, im_size)
|
277 |
+
in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
|
278 |
+
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
|
279 |
+
masks, iou_preds, _ = self.predictor.predict_torch(
|
280 |
+
in_points[:, None, :],
|
281 |
+
in_labels[:, None],
|
282 |
+
multimask_output=True,
|
283 |
+
return_logits=True,
|
284 |
+
)
|
285 |
+
|
286 |
+
# Serialize predictions and store in MaskData
|
287 |
+
data = MaskData(
|
288 |
+
masks=masks.flatten(0, 1),
|
289 |
+
iou_preds=iou_preds.flatten(0, 1),
|
290 |
+
points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
|
291 |
+
)
|
292 |
+
del masks
|
293 |
+
|
294 |
+
# Filter by predicted IoU
|
295 |
+
if self.pred_iou_thresh > 0.0:
|
296 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
297 |
+
data.filter(keep_mask)
|
298 |
+
|
299 |
+
# Calculate stability score
|
300 |
+
data["stability_score"] = calculate_stability_score(
|
301 |
+
data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
|
302 |
+
)
|
303 |
+
if self.stability_score_thresh > 0.0:
|
304 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
305 |
+
data.filter(keep_mask)
|
306 |
+
|
307 |
+
# Threshold masks and calculate boxes
|
308 |
+
data["masks"] = data["masks"] > self.predictor.model.mask_threshold
|
309 |
+
data["boxes"] = batched_mask_to_box(data["masks"])
|
310 |
+
|
311 |
+
# Filter boxes that touch crop boundaries
|
312 |
+
keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
|
313 |
+
if not torch.all(keep_mask):
|
314 |
+
data.filter(keep_mask)
|
315 |
+
|
316 |
+
# Compress to RLE
|
317 |
+
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
318 |
+
data["rles"] = mask_to_rle_pytorch(data["masks"])
|
319 |
+
del data["masks"]
|
320 |
+
|
321 |
+
return data
|
322 |
+
|
323 |
+
@staticmethod
|
324 |
+
def postprocess_small_regions(
|
325 |
+
mask_data: MaskData, min_area: int, nms_thresh: float
|
326 |
+
) -> MaskData:
|
327 |
+
"""
|
328 |
+
Removes small disconnected regions and holes in masks, then reruns
|
329 |
+
box NMS to remove any new duplicates.
|
330 |
+
|
331 |
+
Edits mask_data in place.
|
332 |
+
|
333 |
+
Requires open-cv as a dependency.
|
334 |
+
"""
|
335 |
+
if len(mask_data["rles"]) == 0:
|
336 |
+
return mask_data
|
337 |
+
|
338 |
+
# Filter small disconnected regions and holes
|
339 |
+
new_masks = []
|
340 |
+
scores = []
|
341 |
+
for rle in mask_data["rles"]:
|
342 |
+
mask = rle_to_mask(rle)
|
343 |
+
|
344 |
+
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
345 |
+
unchanged = not changed
|
346 |
+
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
347 |
+
unchanged = unchanged and not changed
|
348 |
+
|
349 |
+
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
350 |
+
# Give score=0 to changed masks and score=1 to unchanged masks
|
351 |
+
# so NMS will prefer ones that didn't need postprocessing
|
352 |
+
scores.append(float(unchanged))
|
353 |
+
|
354 |
+
# Recalculate boxes and remove any new duplicates
|
355 |
+
masks = torch.cat(new_masks, dim=0)
|
356 |
+
boxes = batched_mask_to_box(masks)
|
357 |
+
keep_by_nms = batched_nms(
|
358 |
+
boxes.float(),
|
359 |
+
torch.as_tensor(scores),
|
360 |
+
torch.zeros_like(boxes[:, 0]), # categories
|
361 |
+
iou_threshold=nms_thresh,
|
362 |
+
)
|
363 |
+
|
364 |
+
# Only recalculate RLEs for masks that have changed
|
365 |
+
for i_mask in keep_by_nms:
|
366 |
+
if scores[i_mask] == 0.0:
|
367 |
+
mask_torch = masks[i_mask].unsqueeze(0)
|
368 |
+
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
369 |
+
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
|
370 |
+
mask_data.filter(keep_by_nms)
|
371 |
+
|
372 |
+
return mask_data
|
tools/ins_seg/sam/sam_cls/segment_anything/build_sam.py
ADDED
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

from functools import partial

from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer


def build_sam_vit_h(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
    )


build_sam = build_sam_vit_h


def build_sam_vit_l(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
    )


def build_sam_vit_b(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )


sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    return sam
tools/ins_seg/sam/sam_cls/segment_anything/modeling/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .sam import Sam
from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from .transformer import TwoWayTransformer
tools/ins_seg/sam/sam_cls/segment_anything/modeling/common.py
ADDED
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from typing import Type


class MLPBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))


# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
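A quick sketch of what LayerNorm2d above does differently from nn.LayerNorm: it normalizes over the channel dimension of an NCHW feature map, per spatial location. The import path assumes the package layout of this repo; adjust it if the module is imported differently.

import torch
from tools.ins_seg.sam.sam_cls.segment_anything.modeling.common import LayerNorm2d

x = torch.randn(2, 256, 64, 64)    # NCHW feature map
ln = LayerNorm2d(256)
y = ln(x)
# With the default weight=1 and bias=0, per-pixel channel statistics are
# approximately zero mean and unit variance.
print(y.mean(dim=1).abs().max().item(), y.var(dim=1, unbiased=False).mean().item())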
tools/ins_seg/sam/sam_cls/segment_anything/modeling/image_encoder.py
ADDED
@@ -0,0 +1,395 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
from typing import Optional, Tuple, Type
|
12 |
+
|
13 |
+
from .common import LayerNorm2d, MLPBlock
|
14 |
+
|
15 |
+
|
16 |
+
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
|
17 |
+
class ImageEncoderViT(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
img_size: int = 1024,
|
21 |
+
patch_size: int = 16,
|
22 |
+
in_chans: int = 3,
|
23 |
+
embed_dim: int = 768,
|
24 |
+
depth: int = 12,
|
25 |
+
num_heads: int = 12,
|
26 |
+
mlp_ratio: float = 4.0,
|
27 |
+
out_chans: int = 256,
|
28 |
+
qkv_bias: bool = True,
|
29 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
30 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
31 |
+
use_abs_pos: bool = True,
|
32 |
+
use_rel_pos: bool = False,
|
33 |
+
rel_pos_zero_init: bool = True,
|
34 |
+
window_size: int = 0,
|
35 |
+
global_attn_indexes: Tuple[int, ...] = (),
|
36 |
+
) -> None:
|
37 |
+
"""
|
38 |
+
Args:
|
39 |
+
img_size (int): Input image size.
|
40 |
+
patch_size (int): Patch size.
|
41 |
+
in_chans (int): Number of input image channels.
|
42 |
+
embed_dim (int): Patch embedding dimension.
|
43 |
+
depth (int): Depth of ViT.
|
44 |
+
num_heads (int): Number of attention heads in each ViT block.
|
45 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
46 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
47 |
+
norm_layer (nn.Module): Normalization layer.
|
48 |
+
act_layer (nn.Module): Activation layer.
|
49 |
+
use_abs_pos (bool): If True, use absolute positional embeddings.
|
50 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
51 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
52 |
+
window_size (int): Window size for window attention blocks.
|
53 |
+
global_attn_indexes (list): Indexes for blocks using global attention.
|
54 |
+
"""
|
55 |
+
super().__init__()
|
56 |
+
self.img_size = img_size
|
57 |
+
|
58 |
+
self.patch_embed = PatchEmbed(
|
59 |
+
kernel_size=(patch_size, patch_size),
|
60 |
+
stride=(patch_size, patch_size),
|
61 |
+
in_chans=in_chans,
|
62 |
+
embed_dim=embed_dim,
|
63 |
+
)
|
64 |
+
|
65 |
+
self.pos_embed: Optional[nn.Parameter] = None
|
66 |
+
if use_abs_pos:
|
67 |
+
# Initialize absolute positional embedding with pretrain image size.
|
68 |
+
self.pos_embed = nn.Parameter(
|
69 |
+
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
|
70 |
+
)
|
71 |
+
|
72 |
+
self.blocks = nn.ModuleList()
|
73 |
+
for i in range(depth):
|
74 |
+
block = Block(
|
75 |
+
dim=embed_dim,
|
76 |
+
num_heads=num_heads,
|
77 |
+
mlp_ratio=mlp_ratio,
|
78 |
+
qkv_bias=qkv_bias,
|
79 |
+
norm_layer=norm_layer,
|
80 |
+
act_layer=act_layer,
|
81 |
+
use_rel_pos=use_rel_pos,
|
82 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
83 |
+
window_size=window_size if i not in global_attn_indexes else 0,
|
84 |
+
input_size=(img_size // patch_size, img_size // patch_size),
|
85 |
+
)
|
86 |
+
self.blocks.append(block)
|
87 |
+
|
88 |
+
self.neck = nn.Sequential(
|
89 |
+
nn.Conv2d(
|
90 |
+
embed_dim,
|
91 |
+
out_chans,
|
92 |
+
kernel_size=1,
|
93 |
+
bias=False,
|
94 |
+
),
|
95 |
+
LayerNorm2d(out_chans),
|
96 |
+
nn.Conv2d(
|
97 |
+
out_chans,
|
98 |
+
out_chans,
|
99 |
+
kernel_size=3,
|
100 |
+
padding=1,
|
101 |
+
bias=False,
|
102 |
+
),
|
103 |
+
LayerNorm2d(out_chans),
|
104 |
+
)
|
105 |
+
|
106 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
107 |
+
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        for blk in self.blocks:
            x = blk(x)

        x = self.neck(x.permute(0, 3, 1, 2))

        return x


class Block(nn.Module):
    """Transformer blocks with support of window attention and residual propagation blocks"""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )

        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x


class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

        attn = attn.softmax(dim=-1)
        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        x = self.proj(x)

        return x


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Window unpartition into original sequences and removing padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x
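A quick round-trip sanity check for the two windowing helpers above (a sketch, not part of the uploaded file; it assumes the vendored segment_anything package is importable the same way predictor.py imports it):

    import torch
    from segment_anything.modeling.image_encoder import window_partition, window_unpartition

    x = torch.randn(2, 50, 50, 96)                    # (B, H, W, C), H and W not multiples of 14
    windows, (Hp, Wp) = window_partition(x, 14)       # zero-pads to (56, 56) before splitting
    print(windows.shape)                              # torch.Size([32, 14, 14, 96]) -> 2 * 4 * 4 windows
    x_back = window_unpartition(windows, 14, (Hp, Wp), (50, 50))
    assert torch.equal(x, x_back)                     # padding is cropped away, original values unchanged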
tools/ins_seg/sam/sam_cls/segment_anything/modeling/mask_decoder.py
ADDED
@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import List, Tuple, Type

from .common import LayerNorm2d


class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred


# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
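A minimal shape check for the decoder above (a sketch with random tensors; transformer_dim=256 and the TwoWayTransformer settings mirror the defaults used in build_sam.py, and the import path assumes the vendored package is on sys.path):

    import torch
    from segment_anything.modeling import MaskDecoder, TwoWayTransformer

    decoder = MaskDecoder(
        transformer_dim=256,
        transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
    )
    image_embeddings = torch.randn(1, 256, 64, 64)   # one image embedding from the encoder
    image_pe = torch.randn(1, 256, 64, 64)           # dense positional encoding, same shape
    sparse = torch.randn(1, 2, 256)                  # e.g. one point token plus one padding token
    dense = torch.randn(1, 256, 64, 64)              # stands in for the no-mask embedding
    masks, iou_pred = decoder(image_embeddings, image_pe, sparse, dense, multimask_output=True)
    print(masks.shape, iou_pred.shape)               # torch.Size([1, 3, 256, 256]) torch.Size([1, 3])

With multimask_output=True the first (single-mask) output token is sliced away, which is why 4 mask tokens yield 3 returned masks.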
tools/ins_seg/sam/sam_cls/segment_anything/modeling/prompt_encoder.py
ADDED
@@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from torch import nn

from typing import Any, Optional, Tuple, Type

from .common import LayerNorm2d


class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (int): The padded size of the image as input
            to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
            and labels to embed.
          boxes (torch.Tensor or none): boxes to embed
          masks (torch.Tensor or none): masks to embed

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        return sparse_embeddings, dense_embeddings


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
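A small usage sketch for the prompt encoder above (the constructor values are the typical SAM settings from build_sam.py; note that with no box prompt a padding point is appended, so two clicks produce three sparse tokens):

    import torch
    from segment_anything.modeling import PromptEncoder

    pe = PromptEncoder(
        embed_dim=256,
        image_embedding_size=(64, 64),
        input_image_size=(1024, 1024),
        mask_in_chans=16,
    )
    coords = torch.tensor([[[512.0, 384.0], [100.0, 200.0]]])   # (B=1, N=2) points in pixel space
    labels = torch.tensor([[1, 0]])                             # 1 = foreground, 0 = background
    sparse, dense = pe(points=(coords, labels), boxes=None, masks=None)
    print(sparse.shape, dense.shape)   # torch.Size([1, 3, 256]) torch.Size([1, 256, 64, 64])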
tools/ins_seg/sam/sam_cls/segment_anything/modeling/sam.py
ADDED
@@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import Any, Dict, List, Tuple

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        Arguments:
          batched_input (list(dict)): A list over input images, each a
            dictionary with the following keys. A prompt key can be
            excluded if it is not present.
              'image': The image as a torch tensor in 3xHxW format,
                already transformed for input to the model.
              'original_size': (tuple(int, int)) The original size of
                the image before transformation, as (H, W).
              'point_coords': (torch.Tensor) Batched point prompts for
                this image, with shape BxNx2. Already transformed to the
                input frame of the model.
              'point_labels': (torch.Tensor) Batched labels for point prompts,
                with shape BxN.
              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
                Already transformed to the input frame of the model.
              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
                in the form Bx1xHxW.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.

        Returns:
          (list(dict)): A list over input images, where each element is
            as dictionary with the following keys.
              'masks': (torch.Tensor) Batched binary mask predictions,
                with shape BxCxHxW, where B is the number of input prompts,
                C is determined by multimask_output, and (H, W) is the
                original size of the image.
              'iou_predictions': (torch.Tensor) The model's predictions
                of mask quality, in shape BxC.
              'low_res_logits': (torch.Tensor) Low resolution logits with
                shape BxCxHxW, where H=W=256. Can be passed as mask input
                to subsequent iterations of prediction.
        """
        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
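For the batched interface above, each entry must already be resized to the encoder's input frame. A hedged sketch of building one such entry with ResizeLongestSide from utils/transforms.py (the `sam` instance itself would come from build_sam.py and is not constructed here, so the final call is left commented out):

    import numpy as np
    import torch
    from segment_anything.utils.transforms import ResizeLongestSide

    transform = ResizeLongestSide(1024)               # matches image_encoder.img_size for SAM
    image = np.zeros((768, 1024, 3), dtype=np.uint8)  # placeholder HWC uint8 image
    points = np.array([[[400.0, 300.0]]])             # (B=1, N=1, 2) in original pixel coordinates

    batched_input = [{
        "image": torch.as_tensor(transform.apply_image(image)).permute(2, 0, 1).contiguous(),
        "original_size": image.shape[:2],
        "point_coords": torch.as_tensor(transform.apply_coords(points, image.shape[:2])),
        "point_labels": torch.ones((1, 1), dtype=torch.int),
    }]
    # outputs = sam(batched_input, multimask_output=True)  # `sam` assembled elsewhere, e.g. via build_sam.py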
tools/ins_seg/sam/sam_cls/segment_anything/modeling/transformer.py
ADDED
@@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import Tensor, nn

import math
from typing import Tuple, Type

from .common import MLPBlock


class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys


class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys


class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
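A shape sketch for the downsampling Attention layer above (random tensors; the 256-dim / 8-head layout matches what the decoder uses, and downsample_rate=2 halves the internal projection width while the output is projected back to embedding_dim):

    import torch
    from segment_anything.modeling.transformer import Attention

    attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)  # 256 -> 128 internally
    q = torch.randn(1, 7, 256)          # e.g. output tokens + prompt tokens
    k = v = torch.randn(1, 4096, 256)   # flattened 64x64 image embedding
    out = attn(q, k, v)
    print(out.shape)                    # torch.Size([1, 7, 256])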
tools/ins_seg/sam/sam_cls/segment_anything/predictor.py
ADDED
@@ -0,0 +1,269 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from segment_anything.modeling import Sam
|
11 |
+
|
12 |
+
from typing import Optional, Tuple
|
13 |
+
|
14 |
+
from .utils.transforms import ResizeLongestSide
|
15 |
+
|
16 |
+
|
17 |
+
class SamPredictor:
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
sam_model: Sam,
|
21 |
+
) -> None:
|
22 |
+
"""
|
23 |
+
Uses SAM to calculate the image embedding for an image, and then
|
24 |
+
allow repeated, efficient mask prediction given prompts.
|
25 |
+
|
26 |
+
Arguments:
|
27 |
+
sam_model (Sam): The model to use for mask prediction.
|
28 |
+
"""
|
29 |
+
super().__init__()
|
30 |
+
self.model = sam_model
|
31 |
+
self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
|
32 |
+
self.reset_image()
|
33 |
+
|
34 |
+
def set_image(
|
35 |
+
self,
|
36 |
+
image: np.ndarray,
|
37 |
+
image_format: str = "RGB",
|
38 |
+
) -> None:
|
39 |
+
"""
|
40 |
+
Calculates the image embeddings for the provided image, allowing
|
41 |
+
masks to be predicted with the 'predict' method.
|
42 |
+
|
43 |
+
Arguments:
|
44 |
+
image (np.ndarray): The image for calculating masks. Expects an
|
45 |
+
image in HWC uint8 format, with pixel values in [0, 255].
|
46 |
+
image_format (str): The color format of the image, in ['RGB', 'BGR'].
|
47 |
+
"""
|
48 |
+
assert image_format in [
|
49 |
+
"RGB",
|
50 |
+
"BGR",
|
51 |
+
], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
|
52 |
+
if image_format != self.model.image_format:
|
53 |
+
image = image[..., ::-1]
|
54 |
+
|
55 |
+
# Transform the image to the form expected by the model
|
56 |
+
input_image = self.transform.apply_image(image)
|
57 |
+
input_image_torch = torch.as_tensor(input_image, device=self.device)
|
58 |
+
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
|
59 |
+
|
60 |
+
self.set_torch_image(input_image_torch, image.shape[:2])
|
61 |
+
|
62 |
+
@torch.no_grad()
|
63 |
+
def set_torch_image(
|
64 |
+
self,
|
65 |
+
transformed_image: torch.Tensor,
|
66 |
+
original_image_size: Tuple[int, ...],
|
67 |
+
) -> None:
|
68 |
+
"""
|
69 |
+
Calculates the image embeddings for the provided image, allowing
|
70 |
+
masks to be predicted with the 'predict' method. Expects the input
|
71 |
+
image to be already transformed to the format expected by the model.
|
72 |
+
|
73 |
+
Arguments:
|
74 |
+
transformed_image (torch.Tensor): The input image, with shape
|
75 |
+
1x3xHxW, which has been transformed with ResizeLongestSide.
|
76 |
+
original_image_size (tuple(int, int)): The size of the image
|
77 |
+
before transformation, in (H, W) format.
|
78 |
+
"""
|
79 |
+
assert (
|
80 |
+
len(transformed_image.shape) == 4
|
81 |
+
and transformed_image.shape[1] == 3
|
82 |
+
and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
|
83 |
+
), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
|
84 |
+
self.reset_image()
|
85 |
+
|
86 |
+
self.original_size = original_image_size
|
87 |
+
self.input_size = tuple(transformed_image.shape[-2:])
|
88 |
+
input_image = self.model.preprocess(transformed_image)
|
89 |
+
self.features = self.model.image_encoder(input_image)
|
90 |
+
self.is_image_set = True
|
91 |
+
|
92 |
+
def predict(
|
93 |
+
self,
|
94 |
+
point_coords: Optional[np.ndarray] = None,
|
95 |
+
point_labels: Optional[np.ndarray] = None,
|
96 |
+
box: Optional[np.ndarray] = None,
|
97 |
+
mask_input: Optional[np.ndarray] = None,
|
98 |
+
multimask_output: bool = True,
|
99 |
+
return_logits: bool = False,
|
100 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
101 |
+
"""
|
102 |
+
Predict masks for the given input prompts, using the currently set image.
|
103 |
+
|
104 |
+
Arguments:
|
105 |
+
point_coords (np.ndarray or None): A Nx2 array of point prompts to the
|
106 |
+
model. Each point is in (X,Y) in pixels.
|
107 |
+
point_labels (np.ndarray or None): A length N array of labels for the
|
108 |
+
point prompts. 1 indicates a foreground point and 0 indicates a
|
109 |
+
background point.
|
110 |
+
box (np.ndarray or None): A length 4 array given a box prompt to the
|
111 |
+
model, in XYXY format.
|
112 |
+
mask_input (np.ndarray): A low resolution mask input to the model, typically
|
113 |
+
coming from a previous prediction iteration. Has form 1xHxW, where
|
114 |
+
for SAM, H=W=256.
|
115 |
+
multimask_output (bool): If true, the model will return three masks.
|
116 |
+
For ambiguous input prompts (such as a single click), this will often
|
117 |
+
produce better masks than a single prediction. If only a single
|
118 |
+
mask is needed, the model's predicted quality score can be used
|
119 |
+
to select the best mask. For non-ambiguous prompts, such as multiple
|
120 |
+
input prompts, multimask_output=False can give better results.
|
121 |
+
return_logits (bool): If true, returns un-thresholded masks logits
|
122 |
+
instead of a binary mask.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
(np.ndarray): The output masks in CxHxW format, where C is the
|
126 |
+
number of masks, and (H, W) is the original image size.
|
127 |
+
(np.ndarray): An array of length C containing the model's
|
128 |
+
predictions for the quality of each mask.
|
129 |
+
(np.ndarray): An array of shape CxHxW, where C is the number
|
130 |
+
of masks and H=W=256. These low resolution logits can be passed to
|
131 |
+
a subsequent iteration as mask input.
|
132 |
+
"""
|
133 |
+
if not self.is_image_set:
|
134 |
+
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
|
135 |
+
|
136 |
+
# Transform input prompts
|
137 |
+
coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
|
138 |
+
if point_coords is not None:
|
139 |
+
assert (
|
140 |
+
point_labels is not None
|
141 |
+
), "point_labels must be supplied if point_coords is supplied."
|
142 |
+
point_coords = self.transform.apply_coords(point_coords, self.original_size)
|
143 |
+
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
|
144 |
+
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
|
145 |
+
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
|
146 |
+
if box is not None:
|
147 |
+
box = self.transform.apply_boxes(box, self.original_size)
|
148 |
+
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
|
149 |
+
box_torch = box_torch[None, :]
|
150 |
+
if mask_input is not None:
|
151 |
+
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
|
152 |
+
mask_input_torch = mask_input_torch[None, :, :, :]
|
153 |
+
|
154 |
+
masks, iou_predictions, low_res_masks = self.predict_torch(
|
155 |
+
coords_torch,
|
156 |
+
labels_torch,
|
157 |
+
box_torch,
|
158 |
+
mask_input_torch,
|
159 |
+
multimask_output,
|
160 |
+
return_logits=return_logits,
|
161 |
+
)
|
162 |
+
|
163 |
+
masks_np = masks[0].detach().cpu().numpy()
|
164 |
+
iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
|
165 |
+
low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
|
166 |
+
return masks_np, iou_predictions_np, low_res_masks_np
|
167 |
+
|
168 |
+
@torch.no_grad()
|
169 |
+
def predict_torch(
|
170 |
+
self,
|
171 |
+
point_coords: Optional[torch.Tensor],
|
172 |
+
point_labels: Optional[torch.Tensor],
|
173 |
+
boxes: Optional[torch.Tensor] = None,
|
174 |
+
mask_input: Optional[torch.Tensor] = None,
|
175 |
+
multimask_output: bool = True,
|
176 |
+
return_logits: bool = False,
|
177 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
178 |
+
"""
|
179 |
+
Predict masks for the given input prompts, using the currently set image.
|
180 |
+
Input prompts are batched torch tensors and are expected to already be
|
181 |
+
transformed to the input frame using ResizeLongestSide.
|
182 |
+
|
183 |
+
Arguments:
|
184 |
+
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (np.ndarray or None): A Bx4 array given a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form Bx1xHxW, where
            for SAM, H=W=256. Masks returned by a previous iteration of the
            predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self.is_image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

        if point_coords is not None:
            points = (point_coords, point_labels)
        else:
            points = None

        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        # Predict masks
        low_res_masks, iou_predictions = self.model.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        # Upscale the masks to the original image resolution
        masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)

        if not return_logits:
            masks = masks > self.model.mask_threshold

        return masks, iou_predictions, low_res_masks

    def get_image_embedding(self) -> torch.Tensor:
        """
        Returns the image embeddings for the currently set image, with
        shape 1xCxHxW, where C is the embedding dimension and (H,W) are
        the embedding spatial dimension of SAM (typically C=256, H=W=64).
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert self.features is not None, "Features must exist if an image has been set."
        return self.features

    @property
    def device(self) -> torch.device:
        return self.model.device

    def reset_image(self) -> None:
        """Resets the currently set image."""
        self.is_image_set = False
        self.features = None
        self.orig_h = None
        self.orig_w = None
        self.input_h = None
        self.input_w = None
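The prediction path above expects prompts that have already been mapped into the model's resized input frame. Below is a minimal, hypothetical driver sketch: it assumes the vendored package keeps the upstream segment-anything exports (sam_model_registry, SamPredictor) and exposes the method shown above as predict_torch; the import path, checkpoint path, and click coordinates are placeholders, not part of this upload.

# Sketch only: import path, checkpoint path, and click coordinates are assumptions.
import cv2
import torch
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_h"](checkpoint="checkpoints/sam_vit_h.pth")  # placeholder checkpoint
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("visualizer/test_img.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# One foreground click, transformed into the resized frame expected by the model.
coords = predictor.transform.apply_coords_torch(
    torch.tensor([[[200.0, 150.0]]]), image.shape[:2]
).to(predictor.device)
labels = torch.ones((1, 1), dtype=torch.int, device=predictor.device)

masks, iou_predictions, low_res_masks = predictor.predict_torch(
    point_coords=coords,
    point_labels=labels,
    boxes=None,
    mask_input=None,
    multimask_output=True,  # three candidates for an ambiguous single click
)
best_mask = masks[0, iou_predictions[0].argmax()]  # boolean HxW mask at original resolution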
tools/ins_seg/sam/sam_cls/segment_anything/utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
tools/ins_seg/sam/sam_cls/segment_anything/utils/amg.py
ADDED
@@ -0,0 +1,346 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, Generator, ItemsView, List, Tuple


class MaskData:
    """
    A structure for storing masks and their related data in batched format.
    Implements basic filtering and concatenation.
    """

    def __init__(self, **kwargs) -> None:
        for v in kwargs.values():
            assert isinstance(
                v, (list, np.ndarray, torch.Tensor)
            ), "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats = dict(**kwargs)

    def __setitem__(self, key: str, item: Any) -> None:
        assert isinstance(
            item, (list, np.ndarray, torch.Tensor)
        ), "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats[key] = item

    def __delitem__(self, key: str) -> None:
        del self._stats[key]

    def __getitem__(self, key: str) -> Any:
        return self._stats[key]

    def items(self) -> ItemsView[str, Any]:
        return self._stats.items()

    def filter(self, keep: torch.Tensor) -> None:
        for k, v in self._stats.items():
            if v is None:
                self._stats[k] = None
            elif isinstance(v, torch.Tensor):
                self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
            elif isinstance(v, np.ndarray):
                self._stats[k] = v[keep.detach().cpu().numpy()]
            elif isinstance(v, list) and keep.dtype == torch.bool:
                self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
            elif isinstance(v, list):
                self._stats[k] = [v[i] for i in keep]
            else:
                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")

    def cat(self, new_stats: "MaskData") -> None:
        for k, v in new_stats.items():
            if k not in self._stats or self._stats[k] is None:
                self._stats[k] = deepcopy(v)
            elif isinstance(v, torch.Tensor):
                self._stats[k] = torch.cat([self._stats[k], v], dim=0)
            elif isinstance(v, np.ndarray):
                self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
            elif isinstance(v, list):
                self._stats[k] = self._stats[k] + deepcopy(v)
            else:
                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")

    def to_numpy(self) -> None:
        for k, v in self._stats.items():
            if isinstance(v, torch.Tensor):
                self._stats[k] = v.detach().cpu().numpy()


def is_box_near_crop_edge(
    boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> torch.Tensor:
    """Filter masks at the edge of a crop, but not at the edge of the original image."""
    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
    boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
    return torch.any(near_crop_edge, dim=1)


def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
    box_xywh = deepcopy(box_xyxy)
    box_xywh[2] = box_xywh[2] - box_xywh[0]
    box_xywh[3] = box_xywh[3] - box_xywh[1]
    return box_xywh


def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    assert len(args) > 0 and all(
        len(a) == len(args[0]) for a in args
    ), "Batched iteration must have inputs of all the same size."
    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
    for b in range(n_batches):
        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]


def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
    """
    Encodes masks to an uncompressed RLE, in the format expected by
    pycoco tools.
    """
    # Put in fortran order and flatten h,w
    b, h, w = tensor.shape
    tensor = tensor.permute(0, 2, 1).flatten(1)

    # Compute change indices
    diff = tensor[:, 1:] ^ tensor[:, :-1]
    change_indices = diff.nonzero()

    # Encode run length
    out = []
    for i in range(b):
        cur_idxs = change_indices[change_indices[:, 0] == i, 1]
        cur_idxs = torch.cat(
            [
                torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
                cur_idxs + 1,
                torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
            ]
        )
        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
        counts = [] if tensor[i, 0] == 0 else [0]
        counts.extend(btw_idxs.detach().cpu().tolist())
        out.append({"size": [h, w], "counts": counts})
    return out


def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
    """Compute a binary mask from an uncompressed RLE."""
    h, w = rle["size"]
    mask = np.empty(h * w, dtype=bool)
    idx = 0
    parity = False
    for count in rle["counts"]:
        mask[idx : idx + count] = parity
        idx += count
        parity ^= True
    mask = mask.reshape(w, h)
    return mask.transpose()  # Put in C order


def area_from_rle(rle: Dict[str, Any]) -> int:
    return sum(rle["counts"][1::2])


def calculate_stability_score(
    masks: torch.Tensor, mask_threshold: float, threshold_offset: float
) -> torch.Tensor:
    """
    Computes the stability score for a batch of masks. The stability
    score is the IoU between the binary masks obtained by thresholding
    the predicted mask logits at high and low values.
    """
    # One mask is always contained inside the other.
    # Save memory by preventing unnecessary cast to torch.int64
    intersections = (
        (masks > (mask_threshold + threshold_offset))
        .sum(-1, dtype=torch.int16)
        .sum(-1, dtype=torch.int32)
    )
    unions = (
        (masks > (mask_threshold - threshold_offset))
        .sum(-1, dtype=torch.int16)
        .sum(-1, dtype=torch.int32)
    )
    return intersections / unions


def build_point_grid(n_per_side: int) -> np.ndarray:
    """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
    offset = 1 / (2 * n_per_side)
    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
    points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
    return points


def build_all_layer_point_grids(
    n_per_side: int, n_layers: int, scale_per_layer: int
) -> List[np.ndarray]:
    """Generates point grids for all crop layers."""
    points_by_layer = []
    for i in range(n_layers + 1):
        n_points = int(n_per_side / (scale_per_layer**i))
        points_by_layer.append(build_point_grid(n_points))
    return points_by_layer


def generate_crop_boxes(
    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
    """
    Generates a list of crop boxes of different sizes. Each layer
    has (2**i)**2 boxes for the ith layer.
    """
    crop_boxes, layer_idxs = [], []
    im_h, im_w = im_size
    short_side = min(im_h, im_w)

    # Original image
    crop_boxes.append([0, 0, im_w, im_h])
    layer_idxs.append(0)

    def crop_len(orig_len, n_crops, overlap):
        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))

    for i_layer in range(n_layers):
        n_crops_per_side = 2 ** (i_layer + 1)
        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))

        crop_w = crop_len(im_w, n_crops_per_side, overlap)
        crop_h = crop_len(im_h, n_crops_per_side, overlap)

        crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
        crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]

        # Crops in XYWH format
        for x0, y0 in product(crop_box_x0, crop_box_y0):
            box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
            crop_boxes.append(box)
            layer_idxs.append(i_layer + 1)

    return crop_boxes, layer_idxs


def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    x0, y0, _, _ = crop_box
    offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
    # Check if boxes has a channel dimension
    if len(boxes.shape) == 3:
        offset = offset.unsqueeze(1)
    return boxes + offset


def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    x0, y0, _, _ = crop_box
    offset = torch.tensor([[x0, y0]], device=points.device)
    # Check if points has a channel dimension
    if len(points.shape) == 3:
        offset = offset.unsqueeze(1)
    return points + offset


def uncrop_masks(
    masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
) -> torch.Tensor:
    x0, y0, x1, y1 = crop_box
    if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
        return masks
    # Coordinate transform masks
    pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
    pad = (x0, pad_x - x0, y0, pad_y - y0)
    return torch.nn.functional.pad(masks, pad, value=0)


def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of if the mask has been modified.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True


def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
    from pycocotools import mask as mask_utils  # type: ignore

    h, w = uncompressed_rle["size"]
    rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
    rle["counts"] = rle["counts"].decode("utf-8")  # Necessary to serialize with json
    return rle


def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
    """
    Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
    an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
    """
    # torch.max below raises an error on empty inputs, just skip in this case
    if torch.numel(masks) == 0:
        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)

    # Normalize shape to CxHxW
    shape = masks.shape
    h, w = shape[-2:]
    if len(shape) > 2:
        masks = masks.flatten(0, -3)
    else:
        masks = masks.unsqueeze(0)

    # Get top and bottom edges
    in_height, _ = torch.max(masks, dim=-1)
    in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
    in_height_coords = in_height_coords + h * (~in_height)
    top_edges, _ = torch.min(in_height_coords, dim=-1)

    # Get left and right edges
    in_width, _ = torch.max(masks, dim=-2)
    in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
    right_edges, _ = torch.max(in_width_coords, dim=-1)
    in_width_coords = in_width_coords + w * (~in_width)
    left_edges, _ = torch.min(in_width_coords, dim=-1)

    # If the mask is empty the right edge will be to the left of the left edge.
    # Replace these boxes with [0, 0, 0, 0]
    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
    out = out * (~empty_filter).unsqueeze(-1)

    # Return to original shape
    if len(shape) > 2:
        out = out.reshape(*shape[:-2], 4)
    else:
        out = out[0]

    return out
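The RLE helpers above are inverse operations by construction: mask_to_rle_pytorch flattens each mask in Fortran order and stores alternating background/foreground run lengths, and rle_to_mask replays those runs. A small self-contained sanity-check sketch follows; the import path simply mirrors the directory layout of this upload and may need adjusting to how the package is actually installed.

# Toy round-trip of the uncompressed RLE encoding defined above.
import numpy as np
import torch
from segment_anything.utils.amg import area_from_rle, mask_to_rle_pytorch, rle_to_mask

masks = torch.zeros((2, 8, 8), dtype=torch.bool)   # BxHxW boolean masks
masks[0, 2:5, 3:7] = True                          # a 3x4 rectangle
masks[1, 0, 0] = True                              # a single foreground pixel

rles = mask_to_rle_pytorch(masks)
for mask, rle in zip(masks, rles):
    assert area_from_rle(rle) == int(mask.sum())            # foreground runs preserve area
    assert np.array_equal(rle_to_mask(rle), mask.numpy())   # decoding inverts encoding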
tools/ins_seg/sam/sam_cls/segment_anything/utils/onnx.py
ADDED
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn import functional as F

from typing import Tuple

from ..modeling import Sam
from .amg import calculate_stability_score


class SamOnnxModel(nn.Module):
    """
    This model should not be called directly, but is used in ONNX export.
    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
    with some functions modified to enable model tracing. Also supports extra
    options controlling what information. See the ONNX export script for details.
    """

    def __init__(
        self,
        model: Sam,
        return_single_mask: bool,
        use_stability_score: bool = False,
        return_extra_metrics: bool = False,
    ) -> None:
        super().__init__()
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.return_single_mask = return_single_mask
        self.use_stability_score = use_stability_score
        self.stability_score_offset = 1.0
        self.return_extra_metrics = return_extra_metrics

    @staticmethod
    def resize_longest_image_size(
        input_image_size: torch.Tensor, longest_side: int
    ) -> torch.Tensor:
        input_image_size = input_image_size.to(torch.float32)
        scale = longest_side / torch.max(input_image_size)
        transformed_size = scale * input_image_size
        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
        return transformed_size

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        point_coords = point_coords + 0.5
        point_coords = point_coords / self.img_size
        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
            point_labels == -1
        )

        for i in range(self.model.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
                i
            ].weight * (point_labels == i)

        return point_embedding

    def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
        mask_embedding = mask_embedding + (
            1 - has_mask_input
        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

    def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
        masks = F.interpolate(
            masks,
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        )

        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore

        orig_im_size = orig_im_size.to(torch.int64)
        h, w = orig_im_size[0], orig_im_size[1]
        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
        return masks

    def select_masks(
        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Determine if we should return the multiclick mask or not from the number of points.
        # The reweighting is used to avoid control flow.
        score_reweight = torch.tensor(
            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
        ).to(iou_preds.device)
        score = iou_preds + (num_points - 2.5) * score_reweight
        best_idx = torch.argmax(score, dim=1)
        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)

        return masks, iou_preds

    @torch.no_grad()
    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
        orig_im_size: torch.Tensor,
    ):
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        if self.return_single_mask:
            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)

        if self.return_extra_metrics:
            stability_scores = calculate_stability_score(
                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
            )
            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
            return upscaled_masks, scores, stability_scores, areas, masks

        return upscaled_masks, scores, masks
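SamOnnxModel is only meant to be traced, as its docstring notes. The sketch below shows one plausible way to export it, loosely following the upstream segment-anything export script; the registry import, checkpoint path, dummy shapes, and output names are all assumptions and not part of this upload.

# Hypothetical export sketch: paths, shapes, and names below are illustrative.
import torch
from segment_anything import sam_model_registry  # assumed upstream-style export
from segment_anything.utils.onnx import SamOnnxModel

sam = sam_model_registry["vit_h"](checkpoint="checkpoints/sam_vit_h.pth")  # placeholder
onnx_model = SamOnnxModel(sam, return_single_mask=True)

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size  # (64, 64) for the default models
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(0, 1024, (1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(0, 4, (1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, 4 * embed_size[0], 4 * embed_size[1], dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
torch.onnx.export(
    onnx_model,
    tuple(dummy_inputs.values()),
    "sam_prompt_decoder.onnx",
    input_names=list(dummy_inputs.keys()),
    output_names=["masks", "iou_predictions", "low_res_masks"],
    dynamic_axes={"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}},
)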
tools/ins_seg/sam/sam_cls/segment_anything/utils/transforms.py
ADDED
@@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image  # type: ignore

from copy import deepcopy
from typing import Tuple


class ResizeLongestSide:
    """
    Resizes images to the longest side 'target_length', as well as provides
    methods for resizing coordinates and boxes. Provides methods for
    transforming both numpy array and batched torch tensors.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
        return np.array(resize(to_pil_image(image), target_size))

    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array of length 2 in the final dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).astype(float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array shape Bx4. Requires the original image size
        in (H, W) format.
        """
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
        """
        Expects batched images with shape BxCxHxW and float format. This
        transformation may not exactly match apply_image. apply_image is
        the transformation expected by the model.
        """
        # Expects an image in BCHW format. May not exactly match apply_image.
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def apply_coords_torch(
        self, coords: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with length 2 in the last dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).to(torch.float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes_torch(
        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with shape Bx4. Requires the original image
        size in (H, W) format.
        """
        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """
        Compute the output size given input size and target long side length.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)
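ResizeLongestSide is the glue between original-resolution inputs and the fixed-size frame the image encoder expects: the same instance is applied to the image and to any point or box prompts so that both end up in the resized coordinate system. A short sketch with made-up sizes follows; the 1024-pixel target matches the default SAM image encoders, and the import path mirrors the directory layout of this upload.

# Sketch: synthetic image and box; adjust the import path to your installation.
import numpy as np
from segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(target_length=1024)

image = np.zeros((600, 800, 3), dtype=np.uint8)     # HxWxC uint8, original resolution
resized = transform.apply_image(image)              # longest side scaled to 1024 -> (768, 1024, 3)

boxes = np.array([[100.0, 50.0, 300.0, 200.0]])     # XYXY in original pixel coordinates
boxes_resized = transform.apply_boxes(boxes, image.shape[:2])  # scaled by 1024/800 = 1.28

print(resized.shape, boxes_resized)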
tools/predict.py
ADDED
@@ -0,0 +1,46 @@
import argparse
import os
import sys
sys.path.insert(0, sys.path[0]+'/..')
from mmengine.config import Config, DictAction
from mmengine.logging import print_log
from mmengine.runner import Runner
from mmpl.engine.runner import PLRunner
import os.path as osp
from mmpl.registry import RUNNERS
from mmpl.utils import register_all_modules
register_all_modules()


def parse_args():
    parser = argparse.ArgumentParser(description='Train a pl model')
    parser.add_argument('--config', default='configs/rsprompter/rsprompter_anchor_whu_config.py',
                        help='train config file path')
    parser.add_argument('--status', default='predict', help='fit or test', choices=['fit', 'test', 'predict', 'validate'])
    parser.add_argument('--ckpt-path',
                        default='pretrain/whu/last.ckpt',
                        help='checkpoint path')
    parser.add_argument('--work-dir', default=None, help='the dir to save logs and mmpl')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.work_dir is not None:
        cfg.trainer_cfg['default_root_dir'] = args.work_dir
    elif cfg.trainer_cfg.get('default_root_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.trainer_cfg['default_root_dir'] = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
    cfg.trainer_cfg['logger'] = None
    if 'runner_type' not in cfg:
        runner = PLRunner.from_cfg(cfg)
    else:
        runner = RUNNERS.build(cfg)
    runner.run(args.status, ckpt_path=args.ckpt_path)


if __name__ == '__main__':
    main()
tools/test.py
ADDED
@@ -0,0 +1,43 @@
import argparse
import os
import sys
sys.path.insert(0, sys.path[0]+'/..')
from mmengine.config import Config, DictAction
from mmengine.logging import print_log
from mmengine.runner import Runner
from mmpl.engine.runner import PLRunner
import os.path as osp
from mmpl.registry import RUNNERS
from mmpl.utils import register_all_modules
register_all_modules()


def parse_args():
    parser = argparse.ArgumentParser(description='Train a pl model')
    parser.add_argument('--config', default='configs/rsprompter/rsprompter_anchor_whu_config.py', help='train config file path')
    parser.add_argument('--status', default='test', help='fit or test', choices=['fit', 'test', 'predict', 'validate'])
    parser.add_argument('--ckpt-path', default='pretrain/last.pth',
                        help='checkpoint path')
    parser.add_argument('--work-dir', default=None, help='the dir to save logs and mmpl')
    args = parser.parse_args()
    return args

def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.work_dir is not None:
        cfg.trainer_cfg['default_root_dir'] = args.work_dir
    elif cfg.trainer_cfg.get('default_root_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.trainer_cfg['default_root_dir'] = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])

    if 'runner_type' not in cfg:
        runner = PLRunner.from_cfg(cfg)
    else:
        runner = RUNNERS.build(cfg)
    runner.run(args.status, ckpt_path=args.ckpt_path)


if __name__ == '__main__':
    main()
ADDED
@@ -0,0 +1,48 @@
|
import argparse
import os
import sys
import torch
sys.path.insert(0, sys.path[0]+'/..')
from mmengine.config import Config, DictAction
from mmpl.engine.runner import PLRunner
import os.path as osp
from mmpl.registry import RUNNERS
from mmpl.utils import register_all_modules

torch.set_float32_matmul_precision('high')
register_all_modules()
# TORCH_DISTRIBUTED_DEBUG=DETAIL

def parse_args():
    parser = argparse.ArgumentParser(description='Train a pl model')
    parser.add_argument('--config', default='configs/rsprompter/rsprompter_anchor_whu_config.py',
                        help='train config file path')
    parser.add_argument('--is-debug', default=False, action='store_true', help='debug mode')
    parser.add_argument('--ckpt-path', default=None, help='checkpoint path')
    parser.add_argument('--status', default='fit', help='fit or test', choices=['fit', 'test', 'predict', 'validate'])
    parser.add_argument('--work-dir', default=None, help='the dir to save logs and mmpl')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.work_dir is not None:
        cfg.trainer_cfg['default_root_dir'] = args.work_dir
    elif cfg.trainer_cfg.get('default_root_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.trainer_cfg['default_root_dir'] = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
    if args.is_debug:
        cfg.trainer_cfg['fast_dev_run'] = True
        cfg.trainer_cfg['logger'] = None
    if 'runner_type' not in cfg:
        runner = PLRunner.from_cfg(cfg)
    else:
        runner = RUNNERS.build(cfg)
    runner.run(args.status, ckpt_path=args.ckpt_path)


if __name__ == '__main__':
    main()
visualizer/test_img.jpg
ADDED