hikerxu commited on
Commit
483de47
1 Parent(s): 8a20308

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +25 -0
  2. .gitignore +135 -0
  3. .gitmodules +7 -0
  4. CITATION.cff +8 -0
  5. Dockerfile +30 -0
  6. EfficientSAM/EdgeSAM/common.py +118 -0
  7. EfficientSAM/EdgeSAM/rep_vit.py +370 -0
  8. EfficientSAM/EdgeSAM/setup_edge_sam.py +90 -0
  9. EfficientSAM/FastSAM/tools.py +413 -0
  10. EfficientSAM/LightHQSAM/example_light_hqsam.png +3 -0
  11. EfficientSAM/LightHQSAM/grounded_light_hqsam_annotated_image.jpg +0 -0
  12. EfficientSAM/LightHQSAM/setup_light_hqsam.py +45 -0
  13. EfficientSAM/LightHQSAM/tiny_vit_sam.py +724 -0
  14. EfficientSAM/MobileSAM/setup_mobile_sam.py +44 -0
  15. EfficientSAM/MobileSAM/tiny_vit_sam.py +716 -0
  16. EfficientSAM/README.md +194 -0
  17. EfficientSAM/RepViTSAM/repvit.py +364 -0
  18. EfficientSAM/RepViTSAM/setup_repvit_sam.py +53 -0
  19. EfficientSAM/grounded_edge_sam.py +107 -0
  20. EfficientSAM/grounded_efficient_sam.py +118 -0
  21. EfficientSAM/grounded_fast_sam.py +141 -0
  22. EfficientSAM/grounded_light_hqsam.py +109 -0
  23. EfficientSAM/grounded_mobile_sam.py +145 -0
  24. EfficientSAM/grounded_repvit_sam.py +107 -0
  25. GroundingDINO/.asset/COCO.png +0 -0
  26. GroundingDINO/.asset/GD_GLIGEN.png +3 -0
  27. GroundingDINO/.asset/GD_SD.png +3 -0
  28. GroundingDINO/.asset/ODinW.png +0 -0
  29. GroundingDINO/.asset/arch.png +0 -0
  30. GroundingDINO/.asset/cats.png +0 -0
  31. GroundingDINO/.asset/hero_figure.png +3 -0
  32. GroundingDINO/LICENSE +201 -0
  33. GroundingDINO/README.md +163 -0
  34. GroundingDINO/demo/gradio_app.py +125 -0
  35. GroundingDINO/demo/inference_on_a_image.py +172 -0
  36. GroundingDINO/groundingdino/__init__.py +0 -0
  37. GroundingDINO/groundingdino/config/GroundingDINO_SwinB.py +43 -0
  38. GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py +43 -0
  39. GroundingDINO/groundingdino/datasets/__init__.py +0 -0
  40. GroundingDINO/groundingdino/datasets/transforms.py +311 -0
  41. GroundingDINO/groundingdino/models/GroundingDINO/__init__.py +15 -0
  42. GroundingDINO/groundingdino/models/GroundingDINO/backbone/__init__.py +1 -0
  43. GroundingDINO/groundingdino/models/GroundingDINO/backbone/backbone.py +221 -0
  44. GroundingDINO/groundingdino/models/GroundingDINO/backbone/position_encoding.py +186 -0
  45. GroundingDINO/groundingdino/models/GroundingDINO/backbone/swin_transformer.py +802 -0
  46. GroundingDINO/groundingdino/models/GroundingDINO/bertwarper.py +273 -0
  47. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h +64 -0
  48. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp +43 -0
  49. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h +35 -0
  50. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu +156 -0
.gitattributes CHANGED
@@ -33,3 +33,28 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ EfficientSAM/LightHQSAM/example_light_hqsam.png filter=lfs diff=lfs merge=lfs -text
37
+ GroundingDINO/.asset/GD_GLIGEN.png filter=lfs diff=lfs merge=lfs -text
38
+ GroundingDINO/.asset/GD_SD.png filter=lfs diff=lfs merge=lfs -text
39
+ GroundingDINO/.asset/hero_figure.png filter=lfs diff=lfs merge=lfs -text
40
+ VISAM/thirdparty/segment_anything/assets/masks1.png filter=lfs diff=lfs merge=lfs -text
41
+ VISAM/thirdparty/segment_anything/assets/notebook2.png filter=lfs diff=lfs merge=lfs -text
42
+ VISAM/visam.gif filter=lfs diff=lfs merge=lfs -text
43
+ assets/acoustics/gsam_whisper_inpainting_demo.png filter=lfs diff=lfs merge=lfs -text
44
+ assets/acoustics/gsam_whisper_inpainting_pipeline.png filter=lfs diff=lfs merge=lfs -text
45
+ assets/demo9.jpg filter=lfs diff=lfs merge=lfs -text
46
+ assets/gradio_demo.png filter=lfs diff=lfs merge=lfs -text
47
+ assets/grounded_sam_demo3_demo4.png filter=lfs diff=lfs merge=lfs -text
48
+ assets/grounded_sam_inpainting_demo.png filter=lfs diff=lfs merge=lfs -text
49
+ assets/grounded_sam_new_demo_image.png filter=lfs diff=lfs merge=lfs -text
50
+ assets/mask_3dbox.png filter=lfs diff=lfs merge=lfs -text
51
+ assets/osx/grounded_sam_osx_demo.png filter=lfs diff=lfs merge=lfs -text
52
+ assets/osx/grouned_sam_osx_demo.gif filter=lfs diff=lfs merge=lfs -text
53
+ assets/ram_grounded_sam_new.png filter=lfs diff=lfs merge=lfs -text
54
+ segment_anything/assets/masks1.png filter=lfs diff=lfs merge=lfs -text
55
+ segment_anything/assets/notebook2.png filter=lfs diff=lfs merge=lfs -text
56
+ voxelnext_3d_box/images/image_boxes1.png filter=lfs diff=lfs merge=lfs -text
57
+ voxelnext_3d_box/images/image_boxes2.png filter=lfs diff=lfs merge=lfs -text
58
+ voxelnext_3d_box/images/image_boxes3.png filter=lfs diff=lfs merge=lfs -text
59
+ voxelnext_3d_box/images/mask_box.png filter=lfs diff=lfs merge=lfs -text
60
+ voxelnext_3d_box/images/sam-voxelnext.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ # checkpoint
132
+ *.pth
133
+ outputs/
134
+
135
+ .idea/
.gitmodules ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+
2
+ [submodule "grounded-sam-osx"]
3
+ path = grounded-sam-osx
4
+ url = https://github.com/linjing7/grounded-sam-osx.git
5
+ [submodule "VISAM"]
6
+ path = VISAM
7
+ url = https://github.com/BingfengYan/VISAM
CITATION.cff ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ cff-version: 1.2.0
2
+ message: "If you use this software, please cite it as below."
3
+ authors:
4
+ - name: "Grounded-SAM Contributors"
5
+ title: "Grounded-Segment-Anything"
6
+ date-released: 2023-04-06
7
+ url: "https://github.com/IDEA-Research/Grounded-Segment-Anything"
8
+ license: Apache-2.0
Dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel
2
+
3
+ # Arguments to build Docker Image using CUDA
4
+ ARG USE_CUDA=0
5
+ ARG TORCH_ARCH=
6
+
7
+ ENV AM_I_DOCKER True
8
+ ENV BUILD_WITH_CUDA "${USE_CUDA}"
9
+ ENV TORCH_CUDA_ARCH_LIST "${TORCH_ARCH}"
10
+ ENV CUDA_HOME /usr/local/cuda-11.6/
11
+
12
+ RUN mkdir -p /home/appuser/Grounded-Segment-Anything
13
+ COPY . /home/appuser/Grounded-Segment-Anything/
14
+
15
+ RUN apt-get update && apt-get install --no-install-recommends wget ffmpeg=7:* \
16
+ libsm6=2:* libxext6=2:* git=1:* nano=2.* \
17
+ vim=2:* -y \
18
+ && apt-get clean && apt-get autoremove && rm -rf /var/lib/apt/lists/*
19
+
20
+ WORKDIR /home/appuser/Grounded-Segment-Anything
21
+ RUN python -m pip install --no-cache-dir -e segment_anything
22
+
23
+ # When using build isolation, PyTorch with newer CUDA is installed and can't compile GroundingDINO
24
+ RUN python -m pip install --no-cache-dir wheel
25
+ RUN python -m pip install --no-cache-dir --no-build-isolation -e GroundingDINO
26
+
27
+ WORKDIR /home/appuser
28
+ RUN pip install --no-cache-dir diffusers[torch]==0.15.1 opencv-python==4.7.0.72 \
29
+ pycocotools==2.0.6 matplotlib==3.5.3 \
30
+ onnxruntime==1.14.1 onnx==1.13.1 ipykernel==6.16.2 scipy gradio openai
EfficientSAM/EdgeSAM/common.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from typing import Type
12
+
13
+
14
+ class MLPBlock(nn.Module):
15
+ def __init__(
16
+ self,
17
+ embedding_dim: int,
18
+ mlp_dim: int,
19
+ act: Type[nn.Module] = nn.GELU,
20
+ ) -> None:
21
+ super().__init__()
22
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
23
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
24
+ self.act = act()
25
+
26
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
27
+ return self.lin2(self.act(self.lin1(x)))
28
+
29
+
30
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
31
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
32
+ class LayerNorm2d(nn.Module):
33
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
34
+ super().__init__()
35
+ self.weight = nn.Parameter(torch.ones(num_channels))
36
+ self.bias = nn.Parameter(torch.zeros(num_channels))
37
+ self.eps = eps
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+ u = x.mean(1, keepdim=True)
41
+ s = (x - u).pow(2).mean(1, keepdim=True)
42
+ x = (x - u) / torch.sqrt(s + self.eps)
43
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
44
+ return x
45
+
46
+
47
+ def val2list(x: list or tuple or any, repeat_time=1) -> list:
48
+ if isinstance(x, (list, tuple)):
49
+ return list(x)
50
+ return [x for _ in range(repeat_time)]
51
+
52
+
53
+ def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple:
54
+ x = val2list(x)
55
+
56
+ # repeat elements if necessary
57
+ if len(x) > 0:
58
+ x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
59
+
60
+ return tuple(x)
61
+
62
+
63
+ def list_sum(x: list) -> any:
64
+ return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])
65
+
66
+
67
+ def resize(
68
+ x: torch.Tensor,
69
+ size: any or None = None,
70
+ scale_factor=None,
71
+ mode: str = "bicubic",
72
+ align_corners: bool or None = False,
73
+ ) -> torch.Tensor:
74
+ if mode in ["bilinear", "bicubic"]:
75
+ return F.interpolate(
76
+ x,
77
+ size=size,
78
+ scale_factor=scale_factor,
79
+ mode=mode,
80
+ align_corners=align_corners,
81
+ )
82
+ elif mode in ["nearest", "area"]:
83
+ return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
84
+ else:
85
+ raise NotImplementedError(f"resize(mode={mode}) not implemented.")
86
+
87
+
88
+ class UpSampleLayer(nn.Module):
89
+ def __init__(
90
+ self,
91
+ mode="bicubic",
92
+ size=None,
93
+ factor=2,
94
+ align_corners=False,
95
+ ):
96
+ super(UpSampleLayer, self).__init__()
97
+ self.mode = mode
98
+ self.size = val2list(size, 2) if size is not None else None
99
+ self.factor = None if self.size is not None else factor
100
+ self.align_corners = align_corners
101
+
102
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
103
+ return resize(x, self.size, self.factor, self.mode, self.align_corners)
104
+
105
+
106
+ class OpSequential(nn.Module):
107
+ def __init__(self, op_list):
108
+ super(OpSequential, self).__init__()
109
+ valid_op_list = []
110
+ for op in op_list:
111
+ if op is not None:
112
+ valid_op_list.append(op)
113
+ self.op_list = nn.ModuleList(valid_op_list)
114
+
115
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
116
+ for op in self.op_list:
117
+ x = op(x)
118
+ return x
EfficientSAM/EdgeSAM/rep_vit.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from EdgeSAM.common import LayerNorm2d, UpSampleLayer, OpSequential
3
+
4
+ __all__ = ['rep_vit_m1', 'rep_vit_m2', 'rep_vit_m3', 'RepViT']
5
+
6
+ m1_cfgs = [
7
+ # k, t, c, SE, HS, s
8
+ [3, 2, 48, 1, 0, 1],
9
+ [3, 2, 48, 0, 0, 1],
10
+ [3, 2, 48, 0, 0, 1],
11
+ [3, 2, 96, 0, 0, 2],
12
+ [3, 2, 96, 1, 0, 1],
13
+ [3, 2, 96, 0, 0, 1],
14
+ [3, 2, 96, 0, 0, 1],
15
+ [3, 2, 192, 0, 1, 2],
16
+ [3, 2, 192, 1, 1, 1],
17
+ [3, 2, 192, 0, 1, 1],
18
+ [3, 2, 192, 1, 1, 1],
19
+ [3, 2, 192, 0, 1, 1],
20
+ [3, 2, 192, 1, 1, 1],
21
+ [3, 2, 192, 0, 1, 1],
22
+ [3, 2, 192, 1, 1, 1],
23
+ [3, 2, 192, 0, 1, 1],
24
+ [3, 2, 192, 1, 1, 1],
25
+ [3, 2, 192, 0, 1, 1],
26
+ [3, 2, 192, 1, 1, 1],
27
+ [3, 2, 192, 0, 1, 1],
28
+ [3, 2, 192, 1, 1, 1],
29
+ [3, 2, 192, 0, 1, 1],
30
+ [3, 2, 192, 0, 1, 1],
31
+ [3, 2, 384, 0, 1, 2],
32
+ [3, 2, 384, 1, 1, 1],
33
+ [3, 2, 384, 0, 1, 1]
34
+ ]
35
+
36
+ m2_cfgs = [
37
+ # k, t, c, SE, HS, s
38
+ [3, 2, 64, 1, 0, 1],
39
+ [3, 2, 64, 0, 0, 1],
40
+ [3, 2, 64, 0, 0, 1],
41
+ [3, 2, 128, 0, 0, 2],
42
+ [3, 2, 128, 1, 0, 1],
43
+ [3, 2, 128, 0, 0, 1],
44
+ [3, 2, 128, 0, 0, 1],
45
+ [3, 2, 256, 0, 1, 2],
46
+ [3, 2, 256, 1, 1, 1],
47
+ [3, 2, 256, 0, 1, 1],
48
+ [3, 2, 256, 1, 1, 1],
49
+ [3, 2, 256, 0, 1, 1],
50
+ [3, 2, 256, 1, 1, 1],
51
+ [3, 2, 256, 0, 1, 1],
52
+ [3, 2, 256, 1, 1, 1],
53
+ [3, 2, 256, 0, 1, 1],
54
+ [3, 2, 256, 1, 1, 1],
55
+ [3, 2, 256, 0, 1, 1],
56
+ [3, 2, 256, 1, 1, 1],
57
+ [3, 2, 256, 0, 1, 1],
58
+ [3, 2, 256, 0, 1, 1],
59
+ [3, 2, 512, 0, 1, 2],
60
+ [3, 2, 512, 1, 1, 1],
61
+ [3, 2, 512, 0, 1, 1]
62
+ ]
63
+
64
+ m3_cfgs = [
65
+ # k, t, c, SE, HS, s
66
+ [3, 2, 64, 1, 0, 1],
67
+ [3, 2, 64, 0, 0, 1],
68
+ [3, 2, 64, 1, 0, 1],
69
+ [3, 2, 64, 0, 0, 1],
70
+ [3, 2, 64, 0, 0, 1],
71
+ [3, 2, 128, 0, 0, 2],
72
+ [3, 2, 128, 1, 0, 1],
73
+ [3, 2, 128, 0, 0, 1],
74
+ [3, 2, 128, 1, 0, 1],
75
+ [3, 2, 128, 0, 0, 1],
76
+ [3, 2, 128, 0, 0, 1],
77
+ [3, 2, 256, 0, 1, 2],
78
+ [3, 2, 256, 1, 1, 1],
79
+ [3, 2, 256, 0, 1, 1],
80
+ [3, 2, 256, 1, 1, 1],
81
+ [3, 2, 256, 0, 1, 1],
82
+ [3, 2, 256, 1, 1, 1],
83
+ [3, 2, 256, 0, 1, 1],
84
+ [3, 2, 256, 1, 1, 1],
85
+ [3, 2, 256, 0, 1, 1],
86
+ [3, 2, 256, 1, 1, 1],
87
+ [3, 2, 256, 0, 1, 1],
88
+ [3, 2, 256, 1, 1, 1],
89
+ [3, 2, 256, 0, 1, 1],
90
+ [3, 2, 256, 1, 1, 1],
91
+ [3, 2, 256, 0, 1, 1],
92
+ [3, 2, 256, 1, 1, 1],
93
+ [3, 2, 256, 0, 1, 1],
94
+ [3, 2, 256, 1, 1, 1],
95
+ [3, 2, 256, 0, 1, 1],
96
+ [3, 2, 256, 0, 1, 1],
97
+ [3, 2, 512, 0, 1, 2],
98
+ [3, 2, 512, 1, 1, 1],
99
+ [3, 2, 512, 0, 1, 1]
100
+ ]
101
+
102
+
103
+ def _make_divisible(v, divisor, min_value=None):
104
+ """
105
+ This function is taken from the original tf repo.
106
+ It ensures that all layers have a channel number that is divisible by 8
107
+ It can be seen here:
108
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
109
+ :param v:
110
+ :param divisor:
111
+ :param min_value:
112
+ :return:
113
+ """
114
+ if min_value is None:
115
+ min_value = divisor
116
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
117
+ # Make sure that round down does not go down by more than 10%.
118
+ if new_v < 0.9 * v:
119
+ new_v += divisor
120
+ return new_v
121
+
122
+
123
+ from timm.models.layers import SqueezeExcite
124
+
125
+ import torch
126
+
127
+
128
+ class Conv2d_BN(torch.nn.Sequential):
129
+ def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
130
+ groups=1, bn_weight_init=1, resolution=-10000):
131
+ super().__init__()
132
+ self.add_module('c', torch.nn.Conv2d(
133
+ a, b, ks, stride, pad, dilation, groups, bias=False))
134
+ self.add_module('bn', torch.nn.BatchNorm2d(b))
135
+ torch.nn.init.constant_(self.bn.weight, bn_weight_init)
136
+ torch.nn.init.constant_(self.bn.bias, 0)
137
+
138
+ @torch.no_grad()
139
+ def fuse(self):
140
+ c, bn = self._modules.values()
141
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
142
+ w = c.weight * w[:, None, None, None]
143
+ b = bn.bias - bn.running_mean * bn.weight / \
144
+ (bn.running_var + bn.eps) ** 0.5
145
+ m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
146
+ 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation,
147
+ groups=self.c.groups,
148
+ device=c.weight.device)
149
+ m.weight.data.copy_(w)
150
+ m.bias.data.copy_(b)
151
+ return m
152
+
153
+
154
+ class Residual(torch.nn.Module):
155
+ def __init__(self, m, drop=0.):
156
+ super().__init__()
157
+ self.m = m
158
+ self.drop = drop
159
+
160
+ def forward(self, x):
161
+ if self.training and self.drop > 0:
162
+ return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
163
+ device=x.device).ge_(self.drop).div(1 - self.drop).detach()
164
+ else:
165
+ return x + self.m(x)
166
+
167
+ @torch.no_grad()
168
+ def fuse(self):
169
+ if isinstance(self.m, Conv2d_BN):
170
+ m = self.m.fuse()
171
+ assert (m.groups == m.in_channels)
172
+ identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
173
+ identity = torch.nn.functional.pad(identity, [1, 1, 1, 1])
174
+ m.weight += identity.to(m.weight.device)
175
+ return m
176
+ elif isinstance(self.m, torch.nn.Conv2d):
177
+ m = self.m
178
+ assert (m.groups != m.in_channels)
179
+ identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
180
+ identity = torch.nn.functional.pad(identity, [1, 1, 1, 1])
181
+ m.weight += identity.to(m.weight.device)
182
+ return m
183
+ else:
184
+ return self
185
+
186
+
187
+ class RepVGGDW(torch.nn.Module):
188
+ def __init__(self, ed) -> None:
189
+ super().__init__()
190
+ self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
191
+ self.conv1 = Conv2d_BN(ed, ed, 1, 1, 0, groups=ed)
192
+ self.dim = ed
193
+
194
+ def forward(self, x):
195
+ return self.conv(x) + self.conv1(x) + x
196
+
197
+ @torch.no_grad()
198
+ def fuse(self):
199
+ conv = self.conv.fuse()
200
+ conv1 = self.conv1.fuse()
201
+
202
+ conv_w = conv.weight
203
+ conv_b = conv.bias
204
+ conv1_w = conv1.weight
205
+ conv1_b = conv1.bias
206
+
207
+ conv1_w = torch.nn.functional.pad(conv1_w, [1, 1, 1, 1])
208
+
209
+ identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device),
210
+ [1, 1, 1, 1])
211
+
212
+ final_conv_w = conv_w + conv1_w + identity
213
+ final_conv_b = conv_b + conv1_b
214
+
215
+ conv.weight.data.copy_(final_conv_w)
216
+ conv.bias.data.copy_(final_conv_b)
217
+ return conv
218
+
219
+
220
+ class RepViTBlock(nn.Module):
221
+ def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs, skip_downsample=False):
222
+ super(RepViTBlock, self).__init__()
223
+ assert stride in [1, 2]
224
+
225
+ self.identity = stride == 1 and inp == oup
226
+ assert (hidden_dim == 2 * inp)
227
+
228
+ if stride == 2:
229
+ if skip_downsample:
230
+ stride = 1
231
+ self.token_mixer = nn.Sequential(
232
+ Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
233
+ SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
234
+ Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
235
+ )
236
+ self.channel_mixer = Residual(nn.Sequential(
237
+ # pw
238
+ Conv2d_BN(oup, 2 * oup, 1, 1, 0),
239
+ nn.GELU() if use_hs else nn.GELU(),
240
+ # pw-linear
241
+ Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
242
+ ))
243
+ else:
244
+ assert (self.identity)
245
+ self.token_mixer = nn.Sequential(
246
+ RepVGGDW(inp),
247
+ SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
248
+ )
249
+ self.channel_mixer = Residual(nn.Sequential(
250
+ # pw
251
+ Conv2d_BN(inp, hidden_dim, 1, 1, 0),
252
+ nn.GELU() if use_hs else nn.GELU(),
253
+ # pw-linear
254
+ Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
255
+ ))
256
+
257
+ def forward(self, x):
258
+ return self.channel_mixer(self.token_mixer(x))
259
+
260
+
261
+ from timm.models.vision_transformer import trunc_normal_
262
+
263
+
264
+ class BN_Linear(torch.nn.Sequential):
265
+ def __init__(self, a, b, bias=True, std=0.02):
266
+ super().__init__()
267
+ self.add_module('bn', torch.nn.BatchNorm1d(a))
268
+ self.add_module('l', torch.nn.Linear(a, b, bias=bias))
269
+ trunc_normal_(self.l.weight, std=std)
270
+ if bias:
271
+ torch.nn.init.constant_(self.l.bias, 0)
272
+
273
+ @torch.no_grad()
274
+ def fuse(self):
275
+ bn, l = self._modules.values()
276
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
277
+ b = bn.bias - self.bn.running_mean * \
278
+ self.bn.weight / (bn.running_var + bn.eps) ** 0.5
279
+ w = l.weight * w[None, :]
280
+ if l.bias is None:
281
+ b = b @ self.l.weight.T
282
+ else:
283
+ b = (l.weight @ b[:, None]).view(-1) + self.l.bias
284
+ m = torch.nn.Linear(w.size(1), w.size(0), device=l.weight.device)
285
+ m.weight.data.copy_(w)
286
+ m.bias.data.copy_(b)
287
+ return m
288
+
289
+
290
+ class RepViT(nn.Module):
291
+ arch_settings = {
292
+ 'm1': m1_cfgs,
293
+ 'm2': m2_cfgs,
294
+ 'm3': m3_cfgs
295
+ }
296
+
297
+ def __init__(self, arch, img_size=1024, upsample_mode='bicubic'):
298
+ super(RepViT, self).__init__()
299
+ # setting of inverted residual blocks
300
+ self.cfgs = self.arch_settings[arch]
301
+ self.img_size = img_size
302
+
303
+ # building first layer
304
+ input_channel = self.cfgs[0][2]
305
+ patch_embed = torch.nn.Sequential(Conv2d_BN(3, input_channel // 2, 3, 2, 1), torch.nn.GELU(),
306
+ Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1))
307
+ layers = [patch_embed]
308
+ # building inverted residual blocks
309
+ block = RepViTBlock
310
+ self.stage_idx = []
311
+ prev_c = input_channel
312
+ for idx, (k, t, c, use_se, use_hs, s) in enumerate(self.cfgs):
313
+ output_channel = _make_divisible(c, 8)
314
+ exp_size = _make_divisible(input_channel * t, 8)
315
+ skip_downsample = False
316
+ if c != prev_c:
317
+ self.stage_idx.append(idx - 1)
318
+ prev_c = c
319
+ layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs, skip_downsample))
320
+ input_channel = output_channel
321
+ self.stage_idx.append(idx)
322
+ self.features = nn.ModuleList(layers)
323
+
324
+ stage2_channels = _make_divisible(self.cfgs[self.stage_idx[2]][2], 8)
325
+ stage3_channels = _make_divisible(self.cfgs[self.stage_idx[3]][2], 8)
326
+ self.fuse_stage2 = nn.Conv2d(stage2_channels, 256, kernel_size=1, bias=False)
327
+ self.fuse_stage3 = OpSequential([
328
+ nn.Conv2d(stage3_channels, 256, kernel_size=1, bias=False),
329
+ UpSampleLayer(factor=2, mode=upsample_mode),
330
+ ])
331
+
332
+ self.neck = nn.Sequential(
333
+ nn.Conv2d(256, 256, kernel_size=1, bias=False),
334
+ LayerNorm2d(256),
335
+ nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
336
+ LayerNorm2d(256),
337
+ )
338
+
339
+ def forward(self, x):
340
+ counter = 0
341
+ output_dict = dict()
342
+ # patch_embed
343
+ x = self.features[0](x)
344
+ output_dict['stem'] = x
345
+ # stages
346
+ for idx, f in enumerate(self.features[1:]):
347
+ x = f(x)
348
+ if idx in self.stage_idx:
349
+ output_dict[f'stage{counter}'] = x
350
+ counter += 1
351
+
352
+ x = self.fuse_stage2(output_dict['stage2']) + self.fuse_stage3(output_dict['stage3'])
353
+
354
+ x = self.neck(x)
355
+ # hack this place because we modified the predictor of SAM for HQ-SAM in
356
+ # segment_anything/segment_anything/predictor.py line 91 to return intern features of the backbone
357
+ # self.features, self.interm_features = self.model.image_encoder(input_image)
358
+ return x, None
359
+
360
+
361
+ def rep_vit_m1(img_size=1024, **kwargs):
362
+ return RepViT('m1', img_size, **kwargs)
363
+
364
+
365
+ def rep_vit_m2(img_size=1024, **kwargs):
366
+ return RepViT('m2', img_size, **kwargs)
367
+
368
+
369
+ def rep_vit_m3(img_size=1024, **kwargs):
370
+ return RepViT('m3', img_size, **kwargs)
EfficientSAM/EdgeSAM/setup_edge_sam.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+ from functools import partial
10
+
11
+ from segment_anything.modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
12
+ from EdgeSAM.rep_vit import RepViT
13
+
14
+
15
+ prompt_embed_dim = 256
16
+ image_size = 1024
17
+ vit_patch_size = 16
18
+ image_embedding_size = image_size // vit_patch_size
19
+
20
+
21
+ def build_edge_sam(checkpoint=None, upsample_mode="bicubic"):
22
+ image_encoder = RepViT(
23
+ arch="m1",
24
+ img_size=image_size,
25
+ upsample_mode=upsample_mode
26
+ )
27
+ return _build_sam(image_encoder, checkpoint)
28
+
29
+
30
+ sam_model_registry = {
31
+ "default": build_edge_sam,
32
+ "edge_sam": build_edge_sam,
33
+ }
34
+
35
+ def _build_sam_encoder(
36
+ encoder_embed_dim,
37
+ encoder_depth,
38
+ encoder_num_heads,
39
+ encoder_global_attn_indexes,
40
+ ):
41
+ image_encoder = ImageEncoderViT(
42
+ depth=encoder_depth,
43
+ embed_dim=encoder_embed_dim,
44
+ img_size=image_size,
45
+ mlp_ratio=4,
46
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
47
+ num_heads=encoder_num_heads,
48
+ patch_size=vit_patch_size,
49
+ qkv_bias=True,
50
+ use_rel_pos=True,
51
+ global_attn_indexes=encoder_global_attn_indexes,
52
+ window_size=14,
53
+ out_chans=prompt_embed_dim,
54
+ )
55
+ return image_encoder
56
+
57
+
58
+ def _build_sam(
59
+ image_encoder,
60
+ checkpoint=None,
61
+ ):
62
+ sam = Sam(
63
+ image_encoder=image_encoder,
64
+ prompt_encoder=PromptEncoder(
65
+ embed_dim=prompt_embed_dim,
66
+ image_embedding_size=(image_embedding_size, image_embedding_size),
67
+ input_image_size=(image_size, image_size),
68
+ mask_in_chans=16,
69
+ ),
70
+ mask_decoder=MaskDecoder(
71
+ num_multimask_outputs=3,
72
+ transformer=TwoWayTransformer(
73
+ depth=2,
74
+ embedding_dim=prompt_embed_dim,
75
+ mlp_dim=2048,
76
+ num_heads=8,
77
+ ),
78
+ transformer_dim=prompt_embed_dim,
79
+ iou_head_depth=3,
80
+ iou_head_hidden_dim=256,
81
+ ),
82
+ pixel_mean=[123.675, 116.28, 103.53],
83
+ pixel_std=[58.395, 57.12, 57.375],
84
+ )
85
+ sam.eval()
86
+ if checkpoint is not None:
87
+ with open(checkpoint, "rb") as f:
88
+ state_dict = torch.load(f, map_location="cpu")
89
+ sam.load_state_dict(state_dict)
90
+ return sam
EfficientSAM/FastSAM/tools.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import cv2
5
+ import torch
6
+ import os
7
+ import clip
8
+
9
+
10
def convert_box_xywh_to_xyxy(box):
    """Convert a ``[x, y, width, height]`` box to ``[x1, y1, x2, y2]``.

    Args:
        box: Indexable with at least four numeric entries.

    Returns:
        A new four-element list with the opposite corner computed as
        ``x + width`` and ``y + height``.
    """
    left, top = box[0], box[1]
    return [left, top, left + box[2], top + box[3]]
16
+
17
+
18
def segment_image(image, bbox):
    """Return a copy of ``image`` where everything outside ``bbox`` is white.

    Args:
        image: A PIL image (its ``size`` attribute is read and it is converted
            with ``np.array``).
        bbox: ``(x1, y1, x2, y2)`` used directly as array slice bounds.

    Returns:
        A new RGB PIL image of the same size: pixels inside the box are taken
        from ``image``, the rest stay at the (255, 255, 255) fill colour.
    """
    pixels = np.array(image)
    x1, y1, x2, y2 = bbox
    rows, cols = slice(y1, y2), slice(x1, x2)

    # Copy of the source that is zero everywhere except inside the box.
    boxed_pixels = np.zeros_like(pixels)
    boxed_pixels[rows, cols] = pixels[rows, cols]

    # 8-bit paste mask: 255 inside the box, 0 elsewhere.
    paste_mask = np.zeros((pixels.shape[0], pixels.shape[1]), dtype=np.uint8)
    paste_mask[rows, cols] = 255

    canvas = Image.new("RGB", image.size, (255, 255, 255))
    canvas.paste(
        Image.fromarray(boxed_pixels),
        mask=Image.fromarray(paste_mask, mode="L"),
    )
    return canvas
33
+
34
+
35
def format_results(result, filter=0):
    """Turn a detection result into a list of per-mask annotation dicts.

    Args:
        result: Object exposing ``masks.data`` (indexable masks),
            ``boxes.data`` and ``boxes.conf``.
        filter: Masks whose count of pixels equal to 1.0 is below this value
            are skipped.

    Returns:
        A list of dicts with keys ``id``, ``segmentation`` (boolean numpy
        array), ``bbox``, ``score`` and ``area`` (number of set pixels).
    """
    kept = []
    for idx in range(len(result.masks.data)):
        binary_mask = result.masks.data[idx] == 1.0
        if torch.sum(binary_mask) < filter:
            continue
        mask_np = binary_mask.cpu().numpy()
        kept.append(
            {
                "id": idx,
                "segmentation": mask_np,
                "bbox": result.boxes.data[idx],
                "score": result.boxes.conf[idx],
                "area": mask_np.sum(),
            }
        )
    return kept
51
+
52
+
53
def filter_masks(annotations):  # filter out overlapping masks
    """Drop masks that are mostly covered by a strictly larger mask.

    The input list is sorted in place by ``"area"`` (largest first). A smaller
    mask is dropped when more than 80% of its pixels overlap a strictly
    larger one.

    Args:
        annotations: List of dicts with ``"area"`` and a boolean
            ``"segmentation"`` array.

    Returns:
        ``(kept, removed)`` where ``kept`` is the filtered list and
        ``removed`` is the set of dropped positions in the sorted list.
    """
    annotations.sort(key=lambda ann: ann["area"], reverse=True)
    dropped = set()
    total = len(annotations)
    for outer, big in enumerate(annotations):
        for inner in range(outer + 1, total):
            if inner in dropped:
                continue
            small = annotations[inner]
            if not small["area"] < big["area"]:
                continue
            overlap = (big["segmentation"] & small["segmentation"]).sum()
            if overlap / small["segmentation"].sum() > 0.8:
                dropped.add(inner)

    survivors = [ann for pos, ann in enumerate(annotations) if pos not in dropped]
    return survivors, dropped
69
+
70
+
71
def get_bbox_from_mask(mask):
    """Return the bounding box enclosing every external contour of ``mask``.

    Args:
        mask: Numpy array; it is cast to ``uint8`` before contour extraction.

    Returns:
        ``[x1, y1, x2, y2]`` covering the union of the bounding rectangles of
        all external contours. When no contour is found (e.g. an all-zero
        mask) the degenerate box ``[0, 0, 0, 0]`` is returned instead of
        failing on ``contours[0]``.
    """
    mask = mask.astype(np.uint8)
    contours, _hierarchy = cv2.findContours(
        mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if len(contours) == 0:
        return [0, 0, 0, 0]

    # Merge the per-contour rectangles into a single enclosing box.
    rects = [cv2.boundingRect(contour) for contour in contours]
    x1 = min(x for x, _, _, _ in rects)
    y1 = min(y for _, y, _, _ in rects)
    x2 = max(x + w for x, _, w, _ in rects)
    y2 = max(y + h for _, y, _, h in rects)
    return [x1, y1, x2, y2]
89
+
90
+
91
def fast_process(
    annotations, args, mask_random_color, bbox=None, points=None, edges=False
):
    """Draw masks over the image at ``args.img_path`` and return the rendering.

    Args:
        annotations: Masks, either a list of dicts holding a
            ``"segmentation"`` entry, or an array/tensor of masks.
        args: Namespace read for ``img_path``, ``better_quality``, ``device``,
            ``point_label``, ``retina``, ``randomcolor``, ``withContours`` and
            ``output``.
        mask_random_color: Forwarded as ``random_color`` on the CPU path.
        bbox: Optional ``(x1, y1, x2, y2)`` box to draw.
        points: Optional prompt points to draw.
        edges: Unused; kept for interface compatibility.

    Returns:
        ``uint8`` numpy array of shape ``(rows, cols, 3)`` holding the
        rendered figure in RGB order.

    Side effects:
        Creates ``args.output`` if it does not exist (nothing is written to
        it here). The Matplotlib figure created for rendering is closed
        before returning.
    """
    if isinstance(annotations[0], dict):
        annotations = [annotation["segmentation"] for annotation in annotations]
    image = cv2.imread(args.img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_h = image.shape[0]
    original_w = image.shape[1]
    fig = plt.figure(figsize=(original_w / 100, original_h / 100))
    plt.imshow(image)
    if args.better_quality == True:
        if isinstance(annotations[0], torch.Tensor):
            annotations = np.array(annotations.cpu())
        # Close small holes, then remove small specks, in every mask.
        for i, mask in enumerate(annotations):
            mask = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
            )
            annotations[i] = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
            )
    if args.device == "cpu":
        annotations = np.array(annotations)
        fast_show_mask(
            annotations,
            plt.gca(),
            random_color=mask_random_color,
            bbox=bbox,
            points=points,
            pointlabel=args.point_label,
            retinamask=args.retina,
            target_height=original_h,
            target_width=original_w,
        )
    else:
        if isinstance(annotations[0], np.ndarray):
            # np.asarray stacks a list of per-mask arrays; torch.from_numpy
            # itself only accepts an ndarray.
            annotations = torch.from_numpy(np.asarray(annotations))
        # NOTE(review): this path reads args.randomcolor while the CPU path
        # uses mask_random_color - confirm the difference is intended.
        fast_show_mask_gpu(
            annotations,
            plt.gca(),
            random_color=args.randomcolor,
            bbox=bbox,
            points=points,
            pointlabel=args.point_label,
            retinamask=args.retina,
            target_height=original_h,
            target_width=original_w,
        )
    if isinstance(annotations, torch.Tensor):
        annotations = annotations.cpu().numpy()
    if args.withContours == True:
        contour_all = []
        temp = np.zeros((original_h, original_w, 1))
        for mask in annotations:
            if type(mask) == dict:
                mask = mask["segmentation"]
            annotation = mask.astype(np.uint8)
            if args.retina == False:
                annotation = cv2.resize(
                    annotation,
                    (original_w, original_h),
                    interpolation=cv2.INTER_NEAREST,
                )
            contours, hierarchy = cv2.findContours(
                annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )
            contour_all.extend(contours)
        cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
        contour_mask = temp / 255 * color.reshape(1, 1, -1)
        plt.imshow(contour_mask)

    save_path = args.output
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    plt.axis("off")
    plt.draw()

    canvas = fig.canvas
    if hasattr(canvas, "tostring_rgb"):
        # Legacy Matplotlib API. np.frombuffer replaces the deprecated
        # np.fromstring; copy() keeps the result writable as before.
        cols, rows = canvas.get_width_height()
        img_array = (
            np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
            .reshape(rows, cols, 3)
            .copy()
        )
    else:
        # Matplotlib versions without tostring_rgb: drop the alpha channel
        # of the RGBA buffer.
        img_array = np.asarray(canvas.buffer_rgba())[..., :3].copy()
    # Release the figure; otherwise every call leaves one more figure open.
    plt.close(fig)
    # Saving was previously done here with:
    # cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
    return img_array
176
+
177
+
178
+
179
+ # CPU post process
180
def fast_show_mask(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    points=None,
    pointlabel=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Composite a stack of masks into one RGBA overlay and show it on ``ax``.

    Args:
        annotation: Numpy array indexed as ``(mask, row, col)``.
        ax: Axes-like object; ``add_patch`` and ``imshow`` are called on it.
        random_color: When ``True`` each mask gets a random colour, otherwise
            a fixed blue.
        bbox: Optional ``(x1, y1, x2, y2)`` rectangle to outline.
        points: Optional list of ``(x, y)`` points to scatter.
        pointlabel: Per-point labels; 1 is drawn yellow, 0 magenta.
        retinamask: When ``False`` the overlay is resized to the target size.
        target_height: Output height used for the resize.
        target_width: Output width used for the resize.
    """
    num_masks = annotation.shape[0]
    mask_h = annotation.shape[1]
    mask_w = annotation.shape[2]

    # Sort masks by area (ascending) so smaller masks come first.
    order = np.argsort(np.sum(annotation, axis=(1, 2)))
    annotation = annotation[order]

    # For every pixel, the index of the first mask that is non-zero there.
    first_hit = (annotation != 0).argmax(axis=0)

    if random_color == True:
        rgb = np.random.random((num_masks, 1, 1, 3))
    else:
        rgb = np.ones((num_masks, 1, 1, 3)) * np.array(
            [30 / 255, 144 / 255, 255 / 255]
        )
    alpha = np.ones((num_masks, 1, 1, 1)) * 0.6
    rgba = np.concatenate([rgb, alpha], axis=-1)
    colored = np.expand_dims(annotation, -1) * rgba

    # Vectorised gather: each pixel takes the colour of its selected mask.
    overlay = np.zeros((mask_h, mask_w, 4))
    rows, cols = np.meshgrid(np.arange(mask_h), np.arange(mask_w), indexing="ij")
    overlay[rows, cols, :] = colored[first_hit[rows, cols], rows, cols, :]

    if bbox is not None:
        x1, y1, x2, y2 = bbox
        outline = plt.Rectangle(
            (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
        )
        ax.add_patch(outline)

    # Draw prompt points: label 1 first (yellow), then label 0 (magenta).
    if points is not None:
        for wanted_label, marker_color in ((1, "y"), (0, "m")):
            chosen = [
                pt for k, pt in enumerate(points) if pointlabel[k] == wanted_label
            ]
            plt.scatter(
                [pt[0] for pt in chosen],
                [pt[1] for pt in chosen],
                s=20,
                c=marker_color,
            )

    if retinamask == False:
        overlay = cv2.resize(
            overlay, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )
    ax.imshow(overlay)
244
+
245
+
246
def fast_show_mask_gpu(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    points=None,
    pointlabel=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Torch variant of ``fast_show_mask``: composite masks and show them on ``ax``.

    Args:
        annotation: Torch tensor indexed as ``(mask, row, col)``; intermediate
            tensors are created on ``annotation.device``.
        ax: Axes-like object; ``add_patch`` and ``imshow`` are called on it.
        random_color: When ``True`` each mask gets a random colour, otherwise
            a fixed blue.
        bbox: Optional ``(x1, y1, x2, y2)`` rectangle to outline.
        points: Optional list of ``(x, y)`` points to scatter.
        pointlabel: Per-point labels; 1 is drawn yellow, 0 magenta.
        retinamask: When ``False`` the overlay is resized to the target size.
        target_height: Output height used for the resize.
        target_width: Output width used for the resize.
    """
    device = annotation.device
    num_masks = annotation.shape[0]
    mask_h = annotation.shape[1]
    mask_w = annotation.shape[2]

    order = torch.argsort(torch.sum(annotation, dim=(1, 2)), descending=False)
    annotation = annotation[order]

    # Index of the first non-zero mask at every pixel position.
    first_hit = (annotation != 0).to(torch.long).argmax(dim=0)

    if random_color == True:
        rgb = torch.rand((num_masks, 1, 1, 3)).to(device)
    else:
        rgb = torch.ones((num_masks, 1, 1, 3)).to(device) * torch.tensor(
            [30 / 255, 144 / 255, 255 / 255]
        ).to(device)
    alpha = torch.ones((num_masks, 1, 1, 1)).to(device) * 0.6
    rgba = torch.cat([rgb, alpha], dim=-1)
    colored = torch.unsqueeze(annotation, -1) * rgba

    # Gather by index: for each pixel pick the colour of the selected mask,
    # collapsing the per-mask images into a single overlay.
    overlay = torch.zeros((mask_h, mask_w, 4)).to(device)
    rows, cols = torch.meshgrid(
        torch.arange(mask_h), torch.arange(mask_w), indexing="ij"
    )
    overlay[rows, cols, :] = colored[first_hit[rows, cols], rows, cols, :]
    overlay_np = overlay.cpu().numpy()

    if bbox is not None:
        x1, y1, x2, y2 = bbox
        outline = plt.Rectangle(
            (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
        )
        ax.add_patch(outline)

    # Draw prompt points: label 1 first (yellow), then label 0 (magenta).
    if points is not None:
        for wanted_label, marker_color in ((1, "y"), (0, "m")):
            chosen = [
                pt for k, pt in enumerate(points) if pointlabel[k] == wanted_label
            ]
            plt.scatter(
                [pt[0] for pt in chosen],
                [pt[1] for pt in chosen],
                s=20,
                c=marker_color,
            )

    if retinamask == False:
        overlay_np = cv2.resize(
            overlay_np, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )
    ax.imshow(overlay_np)
309
+
310
+
311
+ # clip
312
@torch.no_grad()
def retriev(
    model, preprocess, elements, search_text: str, device
) -> torch.Tensor:
    """Score each image in ``elements`` against ``search_text`` with a CLIP model.

    Args:
        model: Object providing ``encode_image`` and ``encode_text``.
        preprocess: Callable turning one image into a tensor.
        elements: Iterable of images to score.
        search_text: Text query, tokenized with ``clip.tokenize``.
        device: Device the preprocessed images and tokens are moved to.

    Returns:
        A 1-D tensor with one entry per image: the softmax (over images) of
        100x the cosine similarity between image and text features. The
        previous ``-> int`` annotation did not match this return value.
    """
    preprocessed_images = [preprocess(image).to(device) for image in elements]
    tokenized_text = clip.tokenize([search_text]).to(device)
    stacked_images = torch.stack(preprocessed_images)
    image_features = model.encode_image(stacked_images)
    text_features = model.encode_text(tokenized_text)
    # L2-normalize so the dot product below is a cosine similarity.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    probs = 100.0 * image_features @ text_features.T
    return probs[:, 0].softmax(dim=0)
325
+
326
+
327
def crop_image(annotations, image_path):
    """Build one box-masked image per sufficiently large annotation.

    The image is resized to the mask resolution when the two differ.
    Annotations whose mask has 100 or fewer set pixels are skipped and their
    positions recorded.

    Args:
        annotations: List of dicts with a ``"segmentation"`` array.
        image_path: Path opened with ``Image.open``.

    Returns:
        ``(cropped_boxes, cropped_images, not_crop, filter_id, annotations)``.
        Note the naming: ``cropped_boxes`` holds the images produced by
        ``segment_image`` and ``cropped_images`` holds the matching bounding
        boxes. ``not_crop`` is always an empty list, ``filter_id`` lists the
        skipped positions and ``annotations`` is the unchanged input.
    """
    image = Image.open(image_path)
    ori_w, ori_h = image.size
    mask_h, mask_w = annotations[0]["segmentation"].shape
    if (ori_w, ori_h) != (mask_w, mask_h):
        image = image.resize((mask_w, mask_h))

    cropped_boxes = []
    cropped_images = []
    not_crop = []
    filter_id = []
    # annotations, _ = filter_masks(annotations)
    # filter_id = list(_)
    for position, entry in enumerate(annotations):
        segmentation = entry["segmentation"]
        if np.sum(segmentation) <= 100:
            filter_id.append(position)
            continue
        bbox = get_bbox_from_mask(segmentation)  # bounding box of the mask
        cropped_boxes.append(segment_image(image, bbox))  # box-masked image
        # cropped_boxes.append(segment_image(image,mask["segmentation"]))
        cropped_images.append(bbox)  # bounding box of that image

    return cropped_boxes, cropped_images, not_crop, filter_id, annotations
349
+
350
+
351
def box_prompt(masks, bbox, target_height, target_width):
    """Select the mask that best matches a box prompt.

    Args:
        masks: Torch tensor indexed as ``(mask, row, col)``.
        bbox: ``(x1, y1, x2, y2)`` given in the target image coordinate
            frame. It is only read, never modified, so lists and tuples are
            both accepted.
        target_height: Height of the frame ``bbox`` is expressed in.
        target_width: Width of the frame ``bbox`` is expressed in.

    Returns:
        ``(mask, index)``: the selected mask as a numpy array and its index
        along the first dimension of ``masks`` (a 0-d tensor).
    """
    h = masks.shape[1]
    w = masks.shape[2]
    x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
    if h != target_height or w != target_width:
        # Rescale the box from the target frame to the mask resolution.
        x1 = int(x1 * w / target_width)
        y1 = int(y1 * h / target_height)
        x2 = int(x2 * w / target_width)
        y2 = int(y2 * h / target_height)

    # Round to pixel indices and clamp to the mask extent.
    x1 = round(x1) if round(x1) > 0 else 0
    y1 = round(y1) if round(y1) > 0 else 0
    x2 = round(x2) if round(x2) < w else w
    y2 = round(y2) if round(y2) < h else h

    bbox_area = (y2 - y1) * (x2 - x1)

    # IoU between the box and each mask: the intersection is the mask area
    # that falls inside the box.
    masks_area = torch.sum(masks[:, y1:y2, x1:x2], dim=(1, 2))
    orig_masks_area = torch.sum(masks, dim=(1, 2))

    union = bbox_area + orig_masks_area - masks_area
    IoUs = masks_area / union
    max_iou_index = torch.argmax(IoUs)

    return masks[max_iou_index].cpu().numpy(), max_iou_index
377
+
378
+
379
def point_prompt(masks, points, pointlabel, target_height, target_width):  # numpy processing
    """Combine masks according to positive and negative point prompts.

    For every mask and every point lying on it (mask value equal to 1), the
    mask is added to an accumulator when the point's label is 1 and
    subtracted when it is 0. The result is the set of pixels where the
    accumulator is at least 1.

    Args:
        masks: List of annotation dicts (``"segmentation"`` key) or raw
            arrays; the first entry must be a dict, since its segmentation
            defines the working resolution.
        points: List of ``(x, y)`` points in the target coordinate frame.
        pointlabel: Label per point (1 = include, 0 = exclude).
        target_height: Height of the frame ``points`` are expressed in.
        target_width: Width of the frame ``points`` are expressed in.

    Returns:
        ``(onemask, 0)`` where ``onemask`` is a boolean array.
    """
    first_seg = masks[0]["segmentation"]
    h, w = first_seg.shape[0], first_seg.shape[1]
    if (h, w) != (target_height, target_width):
        # Rescale points from the target frame to the mask resolution.
        points = [
            [int(point[0] * w / target_width), int(point[1] * h / target_height)]
            for point in points
        ]

    onemask = np.zeros((h, w))
    for entry in masks:
        mask = entry["segmentation"] if type(entry) == dict else entry
        for k, point in enumerate(points):
            on_mask = mask[point[1], point[0]] == 1
            if on_mask and pointlabel[k] == 1:
                onemask += mask
            if on_mask and pointlabel[k] == 0:
                onemask -= mask
    onemask = onemask >= 1
    return onemask, 0
400
+
401
+
402
def text_prompt(annotations, args):
    """Pick the annotation whose cropped image best matches ``args.text_prompt``.

    Args:
        annotations: List of annotation dicts with a ``"segmentation"`` entry.
        args: Namespace read for ``img_path``, ``device`` and ``text_prompt``.

    Returns:
        ``(segmentation, index)``: the winning annotation's mask and its
        position in ``annotations`` as a 0-d integer tensor.
    """
    cropped_boxes, cropped_images, not_crop, filter_id, annotaions = crop_image(
        annotations, args.img_path
    )
    clip_model, preprocess = clip.load("ViT-B/32", device=args.device)
    scores = retriev(
        clip_model, preprocess, cropped_boxes, args.text_prompt, device=args.device
    )
    best_crop = scores.argsort()[-1]

    # Scores are indexed over the crops only, i.e. the annotations that
    # crop_image did NOT skip. Map the crop index back through the explicit
    # list of kept positions; counting skipped ids <= the crop index gives
    # the wrong answer when several annotations were skipped before it.
    skipped = set(filter_id)
    kept_positions = [
        pos for pos in range(len(annotaions)) if pos not in skipped
    ]
    original_idx = kept_positions[int(best_crop)]
    # Shift the tensor so the returned index keeps its tensor type.
    max_idx = best_crop + (original_idx - int(best_crop))
    return annotaions[original_idx]["segmentation"], max_idx
EfficientSAM/LightHQSAM/example_light_hqsam.png ADDED

Git LFS Details

  • SHA256: 866820ace9a150b791c00f955c2b436fc72a2e6a43b36187aba975be196161c4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.32 MB
EfficientSAM/LightHQSAM/grounded_light_hqsam_annotated_image.jpg ADDED
EfficientSAM/LightHQSAM/setup_light_hqsam.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from LightHQSAM.tiny_vit_sam import TinyViT
2
+ from segment_anything.modeling import MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer
3
+
4
def setup_model():
    """Build a ``Sam`` model with a TinyViT image encoder and a ``MaskDecoderHQ`` decoder.

    No checkpoint is loaded here; the returned model carries freshly
    constructed modules.

    Returns:
        The assembled ``Sam`` instance.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size

    image_encoder = TinyViT(
        img_size=1024,
        in_chans=3,
        num_classes=1000,
        embed_dims=[64, 128, 160, 320],
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 5, 10],
        window_sizes=[7, 7, 14, 7],
        mlp_ratio=4.,
        drop_rate=0.,
        drop_path_rate=0.0,
        use_checkpoint=False,
        mbconv_expand_ratio=4.0,
        local_conv_size=3,
        layer_lr_decay=0.8,
    )
    prompt_encoder = PromptEncoder(
        embed_dim=prompt_embed_dim,
        image_embedding_size=(image_embedding_size, image_embedding_size),
        input_image_size=(image_size, image_size),
        mask_in_chans=16,
    )
    transformer = TwoWayTransformer(
        depth=2,
        embedding_dim=prompt_embed_dim,
        mlp_dim=2048,
        num_heads=8,
    )
    mask_decoder = MaskDecoderHQ(
        num_multimask_outputs=3,
        transformer=transformer,
        transformer_dim=prompt_embed_dim,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
        vit_dim=160,
    )
    return Sam(
        image_encoder=image_encoder,
        prompt_encoder=prompt_encoder,
        mask_decoder=mask_decoder,
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
EfficientSAM/LightHQSAM/tiny_vit_sam.py ADDED
@@ -0,0 +1,724 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # TinyViT Model Architecture
3
+ # Copyright (c) 2022 Microsoft
4
+ # Adapted from LeViT and Swin Transformer
5
+ # LeViT: (https://github.com/facebookresearch/levit)
6
+ # Swin: (https://github.com/microsoft/swin-transformer)
7
+ # Build the TinyViT Model
8
+ # --------------------------------------------------------
9
+
10
+ import itertools
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import DropPath as TimmDropPath,\
16
+ to_2tuple, trunc_normal_
17
+ from timm.models.registry import register_model
18
+ from typing import Tuple
19
+
20
+
21
class Conv2d_BN(torch.nn.Sequential):
    """Bias-free 2-D convolution (``c``) followed by batch normalization (``bn``)."""

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1):
        """Create the conv/bn pair.

        Args:
            a: Input channels.
            b: Output channels.
            ks: Kernel size.
            stride: Convolution stride.
            pad: Convolution padding.
            dilation: Convolution dilation.
            groups: Convolution groups.
            bn_weight_init: Constant the batch-norm weight is initialised to
                (the batch-norm bias is initialised to 0).
        """
        super().__init__()
        conv = torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False)
        self.add_module('c', conv)

        norm = torch.nn.BatchNorm2d(b)
        torch.nn.init.constant_(norm.weight, bn_weight_init)
        torch.nn.init.constant_(norm.bias, 0)
        self.add_module('bn', norm)

    @torch.no_grad()
    def fuse(self):
        """Return a single ``Conv2d`` with the batch-norm statistics folded in."""
        conv, norm = self._modules.values()
        # Per-output-channel scale applied by batch norm in eval mode.
        scale = norm.weight / (norm.running_var + norm.eps)**0.5
        fused_weight = conv.weight * scale[:, None, None, None]
        fused_bias = norm.bias - norm.running_mean * norm.weight / \
            (norm.running_var + norm.eps)**0.5
        fused = torch.nn.Conv2d(
            fused_weight.size(1) * self.c.groups,
            fused_weight.size(0),
            fused_weight.shape[2:],
            stride=self.c.stride,
            padding=self.c.padding,
            dilation=self.c.dilation,
            groups=self.c.groups,
        )
        fused.weight.data.copy_(fused_weight)
        fused.bias.data.copy_(fused_bias)
        return fused
44
+
45
+
46
class DropPath(TimmDropPath):
    """``timm`` DropPath that also reports ``drop_prob`` in its ``repr``."""

    def __init__(self, drop_prob=None):
        super().__init__(drop_prob=drop_prob)
        self.drop_prob = drop_prob

    def __repr__(self):
        # Append the drop probability to the parent representation.
        return f'{super().__repr__()}(drop_prob={self.drop_prob})'
55
+
56
+
57
class PatchEmbed(nn.Module):
    """Two stride-2 ``Conv2d_BN`` layers with an activation in between."""

    def __init__(self, in_chans, embed_dim, resolution, activation):
        """Create the stem.

        Args:
            in_chans: Input channels.
            embed_dim: Output channels; the intermediate layer uses half.
            resolution: Input resolution (int or pair), expanded with
                ``to_2tuple``.
            activation: Activation class instantiated between the two convs.
        """
        super().__init__()
        img_size = to_2tuple(resolution)
        grid_h, grid_w = img_size[0] // 4, img_size[1] // 4
        self.patches_resolution = (grid_h, grid_w)
        self.num_patches = grid_h * grid_w
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        half_dim = embed_dim // 2
        stem = [
            Conv2d_BN(in_chans, half_dim, 3, 2, 1),
            activation(),
            Conv2d_BN(half_dim, embed_dim, 3, 2, 1),
        ]
        self.seq = nn.Sequential(*stem)

    def forward(self, x):
        """Apply the stem to ``x``."""
        return self.seq(x)
75
+
76
+
77
class MBConv(nn.Module):
    """Residual block: 1x1 expand, 3x3 depthwise, 1x1 project, then shortcut add."""

    def __init__(self, in_chans, out_chans, expand_ratio,
                 activation, drop_path):
        """Create the block.

        Args:
            in_chans: Input channels.
            out_chans: Output channels of the projection conv.
            expand_ratio: Hidden width is ``int(in_chans * expand_ratio)``.
            activation: Activation class, instantiated three times.
            drop_path: Drop-path rate; an identity is used when not above 0.
        """
        super().__init__()
        hidden = int(in_chans * expand_ratio)
        self.in_chans = in_chans
        self.hidden_chans = hidden
        self.out_chans = out_chans

        self.conv1 = Conv2d_BN(in_chans, hidden, ks=1)
        self.act1 = activation()

        # groups == channels makes this a depthwise convolution.
        self.conv2 = Conv2d_BN(hidden, hidden,
                               ks=3, stride=1, pad=1, groups=hidden)
        self.act2 = activation()

        self.conv3 = Conv2d_BN(hidden, out_chans, ks=1, bn_weight_init=0.0)
        self.act3 = activation()

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Run the block and add the input back before the final activation."""
        residual = x
        x = self.act1(self.conv1(x))
        x = self.act2(self.conv2(x))
        x = self.drop_path(self.conv3(x))
        x += residual
        return self.act3(x)
116
+
117
+
118
class PatchMerging(nn.Module):
    """Channel change plus optional spatial downsampling between stages."""

    def __init__(self, input_resolution, dim, out_dim, activation):
        """Create the merging layer.

        Args:
            input_resolution: ``(H, W)`` used to reshape token inputs.
            dim: Input channels.
            out_dim: Output channels.
            activation: Activation class, instantiated once and reused.
        """
        super().__init__()

        self.input_resolution = input_resolution
        self.dim = dim
        self.out_dim = out_dim
        self.act = activation()
        self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
        # The depthwise conv keeps the resolution (stride 1) for these output
        # widths and halves it (stride 2) for all others.
        stride_c = 1 if out_dim in (320, 448, 576) else 2
        self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
        self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)

    def forward(self, x):
        """Accept ``(B, L, C)`` tokens or a 4-D map and return ``(B, L', C')`` tokens."""
        if x.ndim == 3:
            H, W = self.input_resolution
            B = len(x)
            # (B, L, C) -> (B, C, H, W)
            x = x.view(B, H, W, -1).permute(0, 3, 1, 2)

        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = self.conv3(x)
        return x.flatten(2).transpose(1, 2)
148
+
149
+
150
class ConvLayer(nn.Module):
    """A stage of ``depth`` MBConv blocks with an optional downsample layer."""

    def __init__(self, dim, input_resolution, depth,
                 activation,
                 drop_path=0., downsample=None, use_checkpoint=False,
                 out_dim=None,
                 conv_expand_ratio=4.,
                 ):
        """Create the stage.

        Args:
            dim: Channels of every block.
            input_resolution: Stored and forwarded to ``downsample``.
            depth: Number of MBConv blocks.
            activation: Activation class.
            drop_path: Single rate, or a list with one rate per block.
            downsample: Optional layer class built after the blocks.
            use_checkpoint: Run blocks through ``checkpoint.checkpoint``.
            out_dim: Output channels passed to ``downsample``.
            conv_expand_ratio: Expansion ratio of each MBConv.
        """

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        def rate_for(block_idx):
            if isinstance(drop_path, list):
                return drop_path[block_idx]
            return drop_path

        # build blocks
        self.blocks = nn.ModuleList(
            MBConv(dim, dim, conv_expand_ratio, activation, rate_for(i))
            for i in range(depth)
        )

        # patch merging layer
        self.downsample = None
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, out_dim=out_dim, activation=activation)

    def forward(self, x):
        """Apply every block in order, then the downsample layer if present."""
        for blk in self.blocks:
            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
        if self.downsample is None:
            return x
        return self.downsample(x)
187
+
188
+
189
class Mlp(nn.Module):
    """LayerNorm followed by a two-layer feed-forward network with dropout."""

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        """Create the MLP.

        Args:
            in_features: Input width (also the LayerNorm width).
            hidden_features: Hidden width; falls back to ``in_features``.
            out_features: Output width; falls back to ``in_features``.
            act_layer: Activation class.
            drop: Dropout probability applied after each linear layer.
        """
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.norm = nn.LayerNorm(in_features)
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.act = act_layer()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """norm -> fc1 -> act -> dropout -> fc2 -> dropout."""
        hidden = self.drop(self.act(self.fc1(self.norm(x))))
        return self.drop(self.fc2(hidden))
210
+
211
+
212
class Attention(torch.nn.Module):
    """Multi-head self-attention with a learned bias per relative offset.

    One bias per head is learned for every distinct absolute ``(dy, dx)``
    offset between two positions of the ``resolution`` grid. Outside
    training mode the gathered ``(num_heads, N, N)`` bias table is cached in
    the non-persistent buffer ``ab``.

    Note: ``ab`` is computed when the module is switched to eval mode. If
    ``attention_biases`` changes afterwards (for example by loading weights),
    call ``eval()`` again to refresh the cache.
    """

    def __init__(self, dim, key_dim, num_heads=8,
                 attn_ratio=4,
                 resolution=(14, 14),
                 ):
        """Create the attention module.

        Args:
            dim: Input and output width.
            key_dim: Per-head query/key width.
            num_heads: Number of heads.
            attn_ratio: Per-head value width is ``int(attn_ratio * key_dim)``.
            resolution: ``(h, w)`` tuple of the token grid attended over.
        """
        super().__init__()
        # (h, w)
        assert isinstance(resolution, tuple) and len(resolution) == 2
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        # Width of the fused projection: values plus queries and keys.
        h = self.dh + nh_kd * 2

        self.norm = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, h)
        self.proj = nn.Linear(self.dh, dim)

        points = list(itertools.product(
            range(resolution[0]), range(resolution[1])))
        N = len(points)
        # Assign one bias slot to each distinct absolute offset and record,
        # for every ordered pair of positions, which slot it uses.
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N, N),
                             persistent=False)

    @torch.no_grad()
    def train(self, mode=True):
        """Switch mode and maintain the cached bias table.

        Entering training mode drops the ``ab`` cache (if any); entering eval
        mode (re)builds it. Returns ``self`` like ``nn.Module.train`` so that
        ``module.eval()`` / ``module.train()`` can be chained.
        """
        super().train(mode)
        if mode:
            if hasattr(self, 'ab'):
                del self.ab
        else:
            self.register_buffer('ab',
                                 self.attention_biases[:, self.attention_bias_idxs],
                                 persistent=False)
        return self

    def forward(self, x):  # x (B,N,C)
        """Apply attention to ``x`` of shape ``(B, N, C)``; output has the same shape."""
        B, N, _ = x.shape

        # Normalization
        x = self.norm(x)

        qkv = self.qkv(x)
        # (B, N, num_heads, d)
        q, k, v = qkv.view(B, N, self.num_heads, -
                           1).split([self.key_dim, self.key_dim, self.d], dim=3)
        # (B, num_heads, N, d)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn = (
            (q @ k.transpose(-2, -1)) * self.scale
            +
            (self.attention_biases[:, self.attention_bias_idxs]
             if self.training else self.ab)
        )
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
        x = self.proj(x)
        return x
285
+
286
+
287
class TinyViTBlock(nn.Module):
    r""" TinyViT Block: windowed attention, a depthwise local conv, then an MLP.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int, int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        local_conv_size (int): the kernel size of the convolution between
                               Attention and MLP. Default: 3
        activation: the activation function. Default: nn.GELU
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7,
                 mlp_ratio=4., drop=0., drop_path=0.,
                 local_conv_size=3,
                 activation=nn.GELU,
                 ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        assert window_size > 0, 'window_size must be greater than 0'
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        assert dim % num_heads == 0, 'dim must be divisible by num_heads'

        self.attn = Attention(dim, dim // num_heads, num_heads,
                              attn_ratio=1,
                              resolution=(window_size, window_size))

        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=activation, drop=drop)

        self.local_conv = Conv2d_BN(
            dim, dim, ks=local_conv_size, stride=1,
            pad=local_conv_size // 2, groups=dim)

    def forward(self, x):
        """Process tokens ``x`` of shape ``(B, H*W, C)``; the shape is preserved."""
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        ws = self.window_size
        shortcut = x
        if H == ws and W == ws:
            # The whole map is a single window.
            x = self.attn(x)
        else:
            x = x.view(B, H, W, C)
            # Pad bottom/right so both sides are multiples of the window size.
            pad_b = (ws - H % ws) % ws
            pad_r = (ws - W % ws) % ws
            needs_pad = pad_b > 0 or pad_r > 0
            if needs_pad:
                x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

            pH, pW = H + pad_b, W + pad_r
            nH, nW = pH // ws, pW // ws
            # window partition
            windows = x.view(B, nH, ws, nW, ws, C).transpose(2, 3).reshape(
                B * nH * nW, ws * ws, C)
            windows = self.attn(windows)
            # window reverse
            x = windows.view(B, nH, nW, ws, ws,
                             C).transpose(2, 3).reshape(B, pH, pW, C)

            if needs_pad:
                x = x[:, :H, :W].contiguous()

            x = x.view(B, L, C)

        x = shortcut + self.drop_path(x)

        # Depthwise conv runs on the (B, C, H, W) layout.
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.local_conv(x)
        x = x.view(B, C, L).transpose(1, 2)

        return x + self.drop_path(self.mlp(x))

    def extra_repr(self) -> str:
        return (
            f"dim={self.dim}, input_resolution={self.input_resolution}, "
            f"num_heads={self.num_heads}, "
            f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
        )
381
+
382
+
383
class BasicLayer(nn.Module):
    """ A basic TinyViT layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate,
            either one value or a list with one value per block. Default: 0.0
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
        activation: the activation function. Default: nn.GELU
        out_dim: the output dimension passed to the downsample layer. Default: None
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., drop=0.,
                 drop_path=0., downsample=None, use_checkpoint=False,
                 local_conv_size=3,
                 activation=nn.GELU,
                 out_dim=None,
                 ):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        def rate_for(block_idx):
            if isinstance(drop_path, list):
                return drop_path[block_idx]
            return drop_path

        # build blocks
        self.blocks = nn.ModuleList(
            TinyViTBlock(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                drop=drop,
                drop_path=rate_for(i),
                local_conv_size=local_conv_size,
                activation=activation,
            )
            for i in range(depth)
        )

        # patch merging layer
        self.downsample = None
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, out_dim=out_dim, activation=activation)

    def forward(self, x):
        """Apply every block in order, then the downsample layer if present."""
        for blk in self.blocks:
            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
        if self.downsample is None:
            return x
        return self.downsample(x)

    def extra_repr(self) -> str:
        return "dim={}, input_resolution={}, depth={}".format(
            self.dim, self.input_resolution, self.depth)
448
+
449
class LayerNorm2d(nn.Module):
    """Layer normalisation over dimension 1 with a learnable per-channel scale and shift.

    The ``[:, None, None]`` broadcast of ``weight``/``bias`` means the last three
    dimensions of the input are treated as (channels, height, width).
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        # Biased variance across the channel dimension.
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.weight[:, None, None] * normalized + self.bias[:, None, None]
462
class TinyViT(nn.Module):
    """TinyViT backbone with a convolutional neck producing 256-channel feature maps.

    The model is a patch-embedding stem, one ``ConvLayer`` stage and
    ``len(depths) - 1`` ``BasicLayer`` stages, followed by ``neck``. A classifier
    head (``norm_head`` / ``head``) is still constructed but is not used by
    ``forward``.

    Args:
        img_size: Input resolution forwarded to ``PatchEmbed``.
        in_chans: Number of input image channels.
        num_classes: Output size of the (unused) classifier head; ``<= 0`` makes it
            an identity.
        embed_dims: Channel width of each stage.
        depths: Number of blocks in each stage.
        num_heads: Attention heads per stage (index 0 is ignored: stage 0 is convolutional).
        window_sizes: Attention window size per stage (index 0 is ignored).
        mlp_ratio: Ratio of MLP hidden dim to embedding dim.
        drop_rate: Dropout rate forwarded to the transformer stages.
        drop_path_rate: Maximum stochastic-depth rate; per-block rates are spaced
            linearly from 0 to this value.
        use_checkpoint: Forwarded to every stage.
        mbconv_expand_ratio: Expansion ratio forwarded to the convolutional stage.
        local_conv_size: Forwarded to the transformer stages.
        layer_lr_decay: Per-block decay used to tag parameters with ``lr_scale``.
    """

    def __init__(self, img_size=224, in_chans=3, num_classes=1000,
                 embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_sizes=[7, 7, 14, 7],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 use_checkpoint=False,
                 mbconv_expand_ratio=4.0,
                 local_conv_size=3,
                 layer_lr_decay=1.0,
                 ):
        super().__init__()
        self.img_size = img_size
        self.num_classes = num_classes
        self.depths = depths
        self.num_layers = len(depths)
        self.mlp_ratio = mlp_ratio

        activation = nn.GELU

        self.patch_embed = PatchEmbed(in_chans=in_chans,
                                      embed_dim=embed_dims[0],
                                      resolution=img_size,
                                      activation=activation)

        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # Stochastic depth decay rule: one rate per block, linearly increasing.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # Stage 3 is declared at the same resolution as stage 2
            # (presumably because the preceding patch merging keeps the
            # resolution for the supported widths -- TODO confirm).
            shift = i_layer - 1 if i_layer == 3 else i_layer
            kwargs = dict(
                dim=embed_dims[i_layer],
                input_resolution=(patches_resolution[0] // (2 ** shift),
                                  patches_resolution[1] // (2 ** shift)),
                depth=depths[i_layer],
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                downsample=PatchMerging if (
                    i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
                activation=activation,
            )
            if i_layer == 0:
                layer = ConvLayer(
                    conv_expand_ratio=mbconv_expand_ratio,
                    **kwargs,
                )
            else:
                layer = BasicLayer(
                    num_heads=num_heads[i_layer],
                    window_size=window_sizes[i_layer],
                    mlp_ratio=self.mlp_ratio,
                    drop=drop_rate,
                    local_conv_size=local_conv_size,
                    **kwargs)
            self.layers.append(layer)

        # Classifier head (kept as attributes; not used by forward()).
        self.norm_head = nn.LayerNorm(embed_dims[-1])
        self.head = nn.Linear(
            embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()

        # init weights
        self.apply(self._init_weights)
        self.set_layer_lr_decay(layer_lr_decay)

        # The neck is created after weight init and lr tagging, so its
        # parameters carry no `lr_scale` attribute.
        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dims[-1],
                256,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(256),
            nn.Conv2d(
                256,
                256,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(256),
        )

    def set_layer_lr_decay(self, layer_lr_decay):
        """Tag every existing parameter with ``lr_scale`` and ``param_name`` attributes.

        Block ``i`` (counted across all stages) gets
        ``layer_lr_decay ** (total_blocks - i - 1)``; a stage's downsample layer
        shares the scale of that stage's last block.
        """
        decay_rate = layer_lr_decay

        # layers -> blocks (depth)
        depth = sum(self.depths)
        lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]

        def _set_lr_scale(m, scale):
            for p in m.parameters():
                p.lr_scale = scale

        self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
        i = 0
        for layer in self.layers:
            for block in layer.blocks:
                block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
                i += 1
            if layer.downsample is not None:
                layer.downsample.apply(
                    lambda x: _set_lr_scale(x, lr_scales[i - 1]))
        assert i == depth
        for m in [self.norm_head, self.head]:
            m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))

        for k, p in self.named_parameters():
            p.param_name = k

        def _check_lr_scale(m):
            for p in m.parameters():
                assert hasattr(p, 'lr_scale'), p.param_name

        self.apply(_check_lr_scale)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'attention_biases'}

    def _stage_output_resolution(self, index):
        """Return the (H, W) token grid produced by stage ``index``.

        A stage followed by another stage emits tokens at the next stage's
        declared ``input_resolution``; the last stage keeps its own.
        """
        following = min(index + 1, len(self.layers) - 1)
        return self.layers[following].input_resolution

    def forward_features(self, x):
        """Encode images and also return the intermediate embedding of stage 1.

        Args:
            x: Image batch of shape (N, C, H, W).

        Returns:
            Tuple ``(features, interm_embeddings)``: the neck output in
            (N, 256, H', W') layout, and a one-element list holding the stage-1
            output reshaped to (N, H1, W1, C1).
        """
        x = self.patch_embed(x)
        x = self.layers[0](x)

        interm_embeddings = []
        for i in range(1, len(self.layers)):
            x = self.layers[i](x)
            if i == 1:
                # Grid size comes from the stage configuration instead of a
                # hard-coded 64x64, so inputs other than 1024x1024 work too.
                h1, w1 = self._stage_output_resolution(i)
                interm_embeddings.append(x.view(x.shape[0], h1, w1, -1))

        B, _, C = x.size()
        H, W = self.layers[-1].input_resolution
        x = x.view(B, H, W, C)
        x = x.permute(0, 3, 1, 2)
        x = self.neck(x)
        return x, interm_embeddings

    def forward(self, x):
        """Return ``(features, interm_embeddings)``; the classifier head is bypassed."""
        x, interm_embeddings = self.forward_features(x)
        return x, interm_embeddings
627
+
628
+
629
+ _checkpoint_url_format = \
630
+ 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/{}.pth'
631
+ _provided_checkpoints = {
632
+ 'tiny_vit_5m_224': 'tiny_vit_5m_22kto1k_distill',
633
+ 'tiny_vit_11m_224': 'tiny_vit_11m_22kto1k_distill',
634
+ 'tiny_vit_21m_224': 'tiny_vit_21m_22kto1k_distill',
635
+ 'tiny_vit_21m_384': 'tiny_vit_21m_22kto1k_384_distill',
636
+ 'tiny_vit_21m_512': 'tiny_vit_21m_22kto1k_512_distill',
637
+ }
638
+
639
+
640
def register_tiny_vit_model(fn):
    """Register a TinyViT factory with `register_model`, adding checkpoint loading.

    The registered wrapper keeps the factory's ``__name__``. When called with
    ``pretrained=True`` it downloads the checkpoint listed for that name in
    ``_provided_checkpoints`` and loads its ``'model'`` entry into the model.

    NOTE: the wrapper accepts ``**kwargs`` but always calls ``fn()`` without
    arguments, so keyword arguments given to the registered model are ignored.
    """
    def fn_wrapper(pretrained=False, **kwargs):
        model = fn()
        if not pretrained:
            return model

        model_name = fn.__name__
        assert model_name in _provided_checkpoints, \
            f'Sorry that the checkpoint `{model_name}` is not provided yet.'
        url = _checkpoint_url_format.format(_provided_checkpoints[model_name])
        state = torch.hub.load_state_dict_from_url(
            url=url,
            map_location='cpu',
            check_hash=False,
        )
        model.load_state_dict(state['model'])
        return model

    # Keep the factory's name so the registry entry matches it.
    fn_wrapper.__name__ = fn.__name__
    return register_model(fn_wrapper)
663
+
664
+
665
@register_tiny_vit_model
def tiny_vit_5m_224(pretrained=False, num_classes=1000, drop_path_rate=0.0):
    """Build the `tiny_vit_5m_224` TinyViT configuration.

    ``pretrained`` is not read here; checkpoint loading is done by the
    registration wrapper.
    """
    config = dict(
        num_classes=num_classes,
        embed_dims=[64, 128, 160, 320],
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 5, 10],
        window_sizes=[7, 7, 14, 7],
        drop_path_rate=drop_path_rate,
    )
    return TinyViT(**config)
675
+
676
+
677
@register_tiny_vit_model
def tiny_vit_11m_224(pretrained=False, num_classes=1000, drop_path_rate=0.1):
    """Build the `tiny_vit_11m_224` TinyViT configuration.

    ``pretrained`` is not read here; checkpoint loading is done by the
    registration wrapper.
    """
    config = dict(
        num_classes=num_classes,
        embed_dims=[64, 128, 256, 448],
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 8, 14],
        window_sizes=[7, 7, 14, 7],
        drop_path_rate=drop_path_rate,
    )
    return TinyViT(**config)
687
+
688
+
689
@register_tiny_vit_model
def tiny_vit_21m_224(pretrained=False, num_classes=1000, drop_path_rate=0.2):
    """Build the `tiny_vit_21m_224` TinyViT configuration.

    ``pretrained`` is not read here; checkpoint loading is done by the
    registration wrapper.
    """
    config = dict(
        num_classes=num_classes,
        embed_dims=[96, 192, 384, 576],
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 18],
        window_sizes=[7, 7, 14, 7],
        drop_path_rate=drop_path_rate,
    )
    return TinyViT(**config)
699
+
700
+
701
@register_tiny_vit_model
def tiny_vit_21m_384(pretrained=False, num_classes=1000, drop_path_rate=0.1):
    """Build the `tiny_vit_21m_384` TinyViT configuration (img_size=384).

    ``pretrained`` is not read here; checkpoint loading is done by the
    registration wrapper.
    """
    config = dict(
        img_size=384,
        num_classes=num_classes,
        embed_dims=[96, 192, 384, 576],
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 18],
        window_sizes=[12, 12, 24, 12],
        drop_path_rate=drop_path_rate,
    )
    return TinyViT(**config)
712
+
713
+
714
@register_tiny_vit_model
def tiny_vit_21m_512(pretrained=False, num_classes=1000, drop_path_rate=0.1):
    """Build the `tiny_vit_21m_512` TinyViT configuration (img_size=512).

    ``pretrained`` is not read here; checkpoint loading is done by the
    registration wrapper.
    """
    config = dict(
        img_size=512,
        num_classes=num_classes,
        embed_dims=[96, 192, 384, 576],
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 18],
        window_sizes=[16, 16, 32, 16],
        drop_path_rate=drop_path_rate,
    )
    return TinyViT(**config)
EfficientSAM/MobileSAM/setup_mobile_sam.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from MobileSAM.tiny_vit_sam import TinyViT
2
+ from segment_anything.modeling import MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
3
+
4
def setup_model():
    """Assemble a `Sam` model whose image encoder is a TinyViT.

    The encoder takes 1024x1024 inputs and the prompt encoder is configured for a
    64x64 embedding grid (1024 // 16). No checkpoint is loaded here; the returned
    model has freshly initialised weights.

    Returns:
        The constructed `Sam` instance.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size

    # Components are built in the same order as they are passed to Sam below.
    image_encoder = TinyViT(
        img_size=1024,
        in_chans=3,
        num_classes=1000,
        embed_dims=[64, 128, 160, 320],
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 5, 10],
        window_sizes=[7, 7, 14, 7],
        mlp_ratio=4.,
        drop_rate=0.,
        drop_path_rate=0.0,
        use_checkpoint=False,
        mbconv_expand_ratio=4.0,
        local_conv_size=3,
        layer_lr_decay=0.8,
    )
    prompt_encoder = PromptEncoder(
        embed_dim=prompt_embed_dim,
        image_embedding_size=(image_embedding_size, image_embedding_size),
        input_image_size=(image_size, image_size),
        mask_in_chans=16,
    )
    two_way_transformer = TwoWayTransformer(
        depth=2,
        embedding_dim=prompt_embed_dim,
        mlp_dim=2048,
        num_heads=8,
    )
    mask_decoder = MaskDecoder(
        num_multimask_outputs=3,
        transformer=two_way_transformer,
        transformer_dim=prompt_embed_dim,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
    )

    return Sam(
        image_encoder=image_encoder,
        prompt_encoder=prompt_encoder,
        mask_decoder=mask_decoder,
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
EfficientSAM/MobileSAM/tiny_vit_sam.py ADDED
@@ -0,0 +1,716 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # TinyViT Model Architecture
3
+ # Copyright (c) 2022 Microsoft
4
+ # Adapted from LeViT and Swin Transformer
5
+ # LeViT: (https://github.com/facebookresearch/levit)
6
+ # Swin: (https://github.com/microsoft/swin-transformer)
7
+ # Build the TinyViT Model
8
+ # --------------------------------------------------------
9
+
10
+ import itertools
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import DropPath as TimmDropPath,\
16
+ to_2tuple, trunc_normal_
17
+ from timm.models.registry import register_model
18
+ from typing import Tuple
19
+
20
+
21
class Conv2d_BN(torch.nn.Sequential):
    """Bias-free Conv2d (submodule ``c``) followed by BatchNorm2d (submodule ``bn``).

    Args:
        a: Input channels.
        b: Output channels.
        ks: Kernel size.
        stride: Convolution stride.
        pad: Convolution padding.
        dilation: Convolution dilation.
        groups: Convolution groups.
        bn_weight_init: Constant used to initialise the BatchNorm weight; the
            BatchNorm bias is initialised to 0.
    """

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1):
        super().__init__()
        conv = torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False)
        norm = torch.nn.BatchNorm2d(b)
        torch.nn.init.constant_(norm.weight, bn_weight_init)
        torch.nn.init.constant_(norm.bias, 0)
        self.add_module('c', conv)
        self.add_module('bn', norm)

    @torch.no_grad()
    def fuse(self):
        """Return a new Conv2d (with bias) that folds the BatchNorm running statistics in."""
        conv, norm = self._modules.values()
        std = (norm.running_var + norm.eps) ** 0.5
        per_channel_scale = norm.weight / std
        fused_weight = conv.weight * per_channel_scale[:, None, None, None]
        fused_bias = norm.bias - norm.running_mean * norm.weight / std

        fused = torch.nn.Conv2d(
            fused_weight.size(1) * self.c.groups,
            fused_weight.size(0),
            fused_weight.shape[2:],
            stride=self.c.stride,
            padding=self.c.padding,
            dilation=self.c.dilation,
            groups=self.c.groups,
        )
        fused.weight.data.copy_(fused_weight)
        fused.bias.data.copy_(fused_bias)
        return fused
44
+
45
+
46
class DropPath(TimmDropPath):
    """timm's DropPath that also stores ``drop_prob`` and shows it in its repr."""

    def __init__(self, drop_prob=None):
        super().__init__(drop_prob=drop_prob)
        self.drop_prob = drop_prob

    def __repr__(self):
        return f'{super().__repr__()}(drop_prob={self.drop_prob})'
55
+
56
+
57
class PatchEmbed(nn.Module):
    """Convolutional stem built from two stride-2 3x3 Conv2d_BN layers.

    ``patches_resolution`` is recorded as the input resolution divided by 4 along
    each side.

    Args:
        in_chans: Number of input channels.
        embed_dim: Number of output channels; the first convolution outputs half of it.
        resolution: Input resolution (an int or a pair, normalised with ``to_2tuple``).
        activation: Activation class instantiated between the two convolutions.
    """

    def __init__(self, in_chans, embed_dim, resolution, activation):
        super().__init__()
        height, width = to_2tuple(resolution)
        grid = (height // 4, width // 4)
        self.patches_resolution = grid
        self.num_patches = grid[0] * grid[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        half_dim = embed_dim // 2
        stem = [
            Conv2d_BN(in_chans, half_dim, 3, 2, 1),
            activation(),
            Conv2d_BN(half_dim, embed_dim, 3, 2, 1),
        ]
        self.seq = nn.Sequential(*stem)

    def forward(self, x):
        return self.seq(x)
75
+
76
+
77
class MBConv(nn.Module):
    """Residual block: 1x1 expand -> 3x3 depthwise -> 1x1 project, plus the input.

    The last Conv2d_BN has its BatchNorm weight initialised to 0. The residual sum
    requires the projected output to match the input shape.

    Args:
        in_chans: Input channels.
        out_chans: Output channels of the final 1x1 convolution.
        expand_ratio: Hidden width is ``int(in_chans * expand_ratio)``.
        activation: Activation class instantiated three times.
        drop_path: Stochastic depth rate; values ``<= 0`` use an identity.
    """

    def __init__(self, in_chans, out_chans, expand_ratio,
                 activation, drop_path):
        super().__init__()
        hidden = int(in_chans * expand_ratio)
        self.in_chans = in_chans
        self.hidden_chans = hidden
        self.out_chans = out_chans

        # Pointwise expansion.
        self.conv1 = Conv2d_BN(in_chans, hidden, ks=1)
        self.act1 = activation()

        # Depthwise 3x3 (groups == channels).
        self.conv2 = Conv2d_BN(hidden, hidden,
                               ks=3, stride=1, pad=1, groups=hidden)
        self.act2 = activation()

        # Pointwise projection.
        self.conv3 = Conv2d_BN(hidden, out_chans, ks=1, bn_weight_init=0.0)
        self.act3 = activation()

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        shortcut = x
        out = self.act1(self.conv1(x))
        out = self.act2(self.conv2(out))
        out = self.drop_path(self.conv3(out))
        # In-place residual add, as in the original implementation.
        out += shortcut
        return self.act3(out)
116
+
117
+
118
class PatchMerging(nn.Module):
    """Transition between stages: 1x1 conv -> depthwise 3x3 conv -> 1x1 conv.

    The depthwise convolution uses stride 2, except when ``out_dim`` is 320, 448
    or 576, where stride 1 keeps the spatial size. The output is flattened to
    (B, H*W, out_dim) tokens.

    Args:
        input_resolution: (H, W) used to unflatten 3-D token inputs.
        dim: Input channels.
        out_dim: Output channels.
        activation: Activation class; one instance is shared by both activations.
    """

    def __init__(self, input_resolution, dim, out_dim, activation):
        super().__init__()

        self.input_resolution = input_resolution
        self.dim = dim
        self.out_dim = out_dim
        self.act = activation()
        self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
        depthwise_stride = 1 if out_dim in (320, 448, 576) else 2
        self.conv2 = Conv2d_BN(out_dim, out_dim, 3, depthwise_stride, 1, groups=out_dim)
        self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)

    def forward(self, x):
        if x.ndim == 3:
            # Token input (B, H*W, C) -> feature map (B, C, H, W).
            height, width = self.input_resolution
            batch = len(x)
            x = x.view(batch, height, width, -1).permute(0, 3, 1, 2)

        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = self.conv3(x)
        # Back to tokens: (B, C, H', W') -> (B, H'*W', C).
        return x.flatten(2).transpose(1, 2)
148
+
149
+
150
class ConvLayer(nn.Module):
    """Convolutional stage: a stack of MBConv blocks plus an optional downsample.

    Args:
        dim: Channels of every MBConv block (input and output).
        input_resolution: Stored and forwarded to the downsample factory.
        depth: Number of MBConv blocks.
        activation: Activation class forwarded to the blocks and the downsample layer.
        drop_path: Stochastic depth rate. A list is indexed per block; any other
            value is passed unchanged to every block.
        downsample: Factory called as
            ``downsample(input_resolution, dim=..., out_dim=..., activation=...)``,
            or None for no downsample layer.
        use_checkpoint: Whether to run each block through ``torch.utils.checkpoint``.
        out_dim: Output dimension forwarded to the downsample factory.
        conv_expand_ratio: Expansion ratio of each MBConv block.
    """

    def __init__(self, dim, input_resolution, depth,
                 activation,
                 drop_path=0., downsample=None, use_checkpoint=False,
                 out_dim=None,
                 conv_expand_ratio=4.,
                 ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # A list supplies one rate per block; a scalar is shared by all blocks.
        rates = drop_path if isinstance(drop_path, list) else [drop_path] * depth

        self.blocks = nn.ModuleList()
        for idx in range(depth):
            self.blocks.append(
                MBConv(dim, dim, conv_expand_ratio, activation, rates[idx]))

        # Optional patch-merging layer applied after the last block.
        self.downsample = None if downsample is None else downsample(
            input_resolution, dim=dim, out_dim=out_dim, activation=activation)

    def forward(self, x):
        """Run all blocks in order, then the downsample layer when one is configured."""
        for block in self.blocks:
            x = checkpoint.checkpoint(block, x) if self.use_checkpoint else block(x)
        if self.downsample is None:
            return x
        return self.downsample(x)
187
+
188
+
189
class Mlp(nn.Module):
    """Pre-normalised two-layer MLP: LayerNorm -> fc1 -> act -> drop -> fc2 -> drop.

    Args:
        in_features: Input width (also the LayerNorm width).
        hidden_features: Hidden width; falls back to ``in_features`` when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Activation class.
        drop: Dropout probability; the same Dropout module is applied twice.
    """

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_width = out_features or in_features
        hidden_width = hidden_features or in_features
        self.norm = nn.LayerNorm(in_features)
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.fc2 = nn.Linear(hidden_width, out_width)
        self.act = act_layer()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(self.norm(x))))
        return self.drop(self.fc2(hidden))
210
+
211
+
212
class Attention(torch.nn.Module):
    """Multi-head self-attention over a fixed token grid with a learned position bias.

    Every pair of grid positions is mapped to its absolute (row, column) offset;
    each distinct offset owns one learnable bias per head (``attention_biases``),
    looked up through the ``attention_bias_idxs`` table and added to the scaled
    attention logits.

    In eval mode the gathered bias table is cached as the non-persistent buffer
    ``ab``, so it follows ``.to(device)`` and is excluded from ``state_dict``.

    Args:
        dim: Input and output channel width.
        key_dim: Per-head width of queries and keys.
        num_heads: Number of attention heads.
        attn_ratio: Per-head value width is ``int(attn_ratio * key_dim)``.
        resolution: (h, w) tuple of the token grid; inputs must carry h*w tokens.
    """

    def __init__(self, dim, key_dim, num_heads=8,
                 attn_ratio=4,
                 resolution=(14, 14),
                 ):
        super().__init__()
        # (h, w)
        assert isinstance(resolution, tuple) and len(resolution) == 2
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        # Combined width of the fused q/k/v projection.
        h = self.dh + nh_kd * 2

        self.norm = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, h)
        self.proj = nn.Linear(self.dh, dim)

        # Assign one bias slot to every distinct absolute offset between two
        # grid positions.
        points = list(itertools.product(
            range(resolution[0]), range(resolution[1])))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N, N),
                             persistent=False)

    @torch.no_grad()
    def train(self, mode=True):
        """Switch mode, refreshing the cached bias table, and return ``self``.

        Any existing cache is dropped. Entering eval mode (``mode=False``)
        rebuilds it from the current ``attention_biases``. Returning ``self``
        keeps ``module.eval()`` / ``module.train()`` chainable, as
        ``nn.Module.train`` does.
        """
        super().train(mode)
        # Drop a stale cache (buffer or plain attribute) before rebuilding it.
        if hasattr(self, 'ab'):
            del self.ab
        if not mode:
            self.register_buffer(
                'ab',
                self.attention_biases[:, self.attention_bias_idxs],
                persistent=False)
        return self

    def forward(self, x):  # x (B,N,C)
        B, N, _ = x.shape

        # Normalization
        x = self.norm(x)

        qkv = self.qkv(x)
        # (B, N, num_heads, d)
        q, k, v = qkv.view(B, N, self.num_heads, -1).split(
            [self.key_dim, self.key_dim, self.d], dim=3)
        # (B, num_heads, N, d)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        # Use the eval-time cache when present; otherwise gather the biases
        # directly so gradients reach `attention_biases` during training.
        bias = None if self.training else getattr(self, 'ab', None)
        if bias is None:
            bias = self.attention_biases[:, self.attention_bias_idxs]

        attn = (q @ k.transpose(-2, -1)) * self.scale + bias
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
        x = self.proj(x)
        return x
283
+
284
+
285
class TinyViTBlock(nn.Module):
    r""" TinyViT Block.

    Applies window attention with a residual connection, then a depthwise
    convolution over the token grid, then an MLP with a residual connection.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int, int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        local_conv_size (int): the kernel size of the convolution between
                               Attention and MLP. Default: 3
        activation: the activation function. Default: nn.GELU
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7,
                 mlp_ratio=4., drop=0., drop_path=0.,
                 local_conv_size=3,
                 activation=nn.GELU,
                 ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        assert window_size > 0, 'window_size must be greater than 0'
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        assert dim % num_heads == 0, 'dim must be divisible by num_heads'
        per_head_dim = dim // num_heads

        self.attn = Attention(dim, per_head_dim, num_heads,
                              attn_ratio=1,
                              resolution=(window_size, window_size))

        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=activation,
                       drop=drop)

        # Depthwise convolution applied between attention and MLP.
        self.local_conv = Conv2d_BN(
            dim, dim, ks=local_conv_size, stride=1,
            pad=local_conv_size // 2, groups=dim)

    def _windowed_attention(self, x, B, H, W, C):
        """Pad the grid to a multiple of the window size, attend per window, crop back."""
        ws = self.window_size
        tokens = H * W
        x = x.view(B, H, W, C)

        pad_b = (-H) % ws
        pad_r = (-W) % ws
        padded = pad_b > 0 or pad_r > 0
        if padded:
            x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

        pH, pW = H + pad_b, W + pad_r
        nH, nW = pH // ws, pW // ws

        # window partition
        x = x.view(B, nH, ws, nW, ws, C).transpose(2, 3).reshape(
            B * nH * nW, ws * ws, C)
        x = self.attn(x)
        # window reverse
        x = x.view(B, nH, nW, ws, ws, C).transpose(2, 3).reshape(B, pH, pW, C)

        if padded:
            x = x[:, :H, :W].contiguous()

        return x.view(B, tokens, C)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        shortcut = x

        if H == self.window_size and W == self.window_size:
            # The whole grid is a single window.
            x = self.attn(x)
        else:
            x = self._windowed_attention(x, B, H, W, C)

        x = shortcut + self.drop_path(x)

        # Depthwise convolution over the (B, C, H, W) view of the tokens.
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.local_conv(x)
        x = x.view(B, C, L).transpose(1, 2)

        return x + self.drop_path(self.mlp(x))

    def extra_repr(self) -> str:
        shown = ("dim", "input_resolution", "num_heads", "window_size", "mlp_ratio")
        return ", ".join(f"{attr}={getattr(self, attr)}" for attr in shown)
379
+
380
+
381
class BasicLayer(nn.Module):
    """One transformer stage of TinyViT: a stack of TinyViTBlocks plus an optional downsample.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate. A list is
            indexed per block; any other value is passed unchanged to every block.
            Default: 0.0
        downsample (callable | None, optional): Factory called as
            ``downsample(input_resolution, dim=..., out_dim=..., activation=...)`` to
            build the layer applied after the blocks. Default: None
        use_checkpoint (bool): Whether to run each block through
            ``torch.utils.checkpoint``. Default: False.
        local_conv_size: Kernel size forwarded to each TinyViTBlock. Default: 3
        activation: Activation class forwarded to the blocks and the downsample layer.
            Default: nn.GELU
        out_dim: Output dimension forwarded to the downsample factory. Default: None
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., drop=0.,
                 drop_path=0., downsample=None, use_checkpoint=False,
                 local_conv_size=3,
                 activation=nn.GELU,
                 out_dim=None,
                 ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # A list supplies one rate per block; a scalar is shared by all blocks.
        rates = drop_path if isinstance(drop_path, list) else [drop_path] * depth

        self.blocks = nn.ModuleList()
        for idx in range(depth):
            self.blocks.append(
                TinyViTBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    mlp_ratio=mlp_ratio,
                    drop=drop,
                    drop_path=rates[idx],
                    local_conv_size=local_conv_size,
                    activation=activation,
                )
            )

        # Optional patch-merging layer applied after the last block.
        self.downsample = None if downsample is None else downsample(
            input_resolution, dim=dim, out_dim=out_dim, activation=activation)

    def forward(self, x):
        """Run all blocks in order, then the downsample layer when one is configured."""
        for block in self.blocks:
            x = checkpoint.checkpoint(block, x) if self.use_checkpoint else block(x)
        if self.downsample is None:
            return x
        return self.downsample(x)

    def extra_repr(self) -> str:
        shown = ("dim", "input_resolution", "depth")
        return ", ".join(f"{attr}={getattr(self, attr)}" for attr in shown)
446
+
447
class LayerNorm2d(nn.Module):
    """Layer normalisation over dimension 1 with a learnable per-channel scale and shift.

    The ``[:, None, None]`` broadcast of ``weight``/``bias`` means the last three
    dimensions of the input are treated as (channels, height, width).
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        # Biased variance across the channel dimension.
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.weight[:, None, None] * normalized + self.bias[:, None, None]
460
class TinyViT(nn.Module):
    """TinyViT backbone with a convolutional neck producing 256-channel feature maps.

    The model is a patch-embedding stem, one ``ConvLayer`` stage and
    ``len(depths) - 1`` ``BasicLayer`` stages, followed by ``neck``. A classifier
    head (``norm_head`` / ``head``) is still constructed but is not used by
    ``forward``.

    Args:
        img_size: Input resolution forwarded to ``PatchEmbed``.
        in_chans: Number of input image channels.
        num_classes: Output size of the (unused) classifier head; ``<= 0`` makes it
            an identity.
        embed_dims: Channel width of each stage.
        depths: Number of blocks in each stage.
        num_heads: Attention heads per stage (index 0 is ignored: stage 0 is convolutional).
        window_sizes: Attention window size per stage (index 0 is ignored).
        mlp_ratio: Ratio of MLP hidden dim to embedding dim.
        drop_rate: Dropout rate forwarded to the transformer stages.
        drop_path_rate: Maximum stochastic-depth rate; per-block rates are spaced
            linearly from 0 to this value.
        use_checkpoint: Forwarded to every stage.
        mbconv_expand_ratio: Expansion ratio forwarded to the convolutional stage.
        local_conv_size: Forwarded to the transformer stages.
        layer_lr_decay: Per-block decay used to tag parameters with ``lr_scale``.
    """

    def __init__(self, img_size=224, in_chans=3, num_classes=1000,
                 embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_sizes=[7, 7, 14, 7],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 use_checkpoint=False,
                 mbconv_expand_ratio=4.0,
                 local_conv_size=3,
                 layer_lr_decay=1.0,
                 ):
        super().__init__()
        self.img_size = img_size
        self.num_classes = num_classes
        self.depths = depths
        self.num_layers = len(depths)
        self.mlp_ratio = mlp_ratio

        activation = nn.GELU

        self.patch_embed = PatchEmbed(in_chans=in_chans,
                                      embed_dim=embed_dims[0],
                                      resolution=img_size,
                                      activation=activation)

        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # Stochastic depth decay rule: one rate per block, linearly increasing.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # Stage 3 is declared at the same resolution as stage 2: the
            # PatchMerging in front of it uses stride 1 for the supported
            # final widths (320 / 448 / 576).
            shift = i_layer - 1 if i_layer == 3 else i_layer
            kwargs = dict(
                dim=embed_dims[i_layer],
                input_resolution=(patches_resolution[0] // (2 ** shift),
                                  patches_resolution[1] // (2 ** shift)),
                depth=depths[i_layer],
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                downsample=PatchMerging if (
                    i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
                activation=activation,
            )
            if i_layer == 0:
                layer = ConvLayer(
                    conv_expand_ratio=mbconv_expand_ratio,
                    **kwargs,
                )
            else:
                layer = BasicLayer(
                    num_heads=num_heads[i_layer],
                    window_size=window_sizes[i_layer],
                    mlp_ratio=self.mlp_ratio,
                    drop=drop_rate,
                    local_conv_size=local_conv_size,
                    **kwargs)
            self.layers.append(layer)

        # Classifier head (kept as attributes; not used by forward()).
        self.norm_head = nn.LayerNorm(embed_dims[-1])
        self.head = nn.Linear(
            embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()

        # init weights
        self.apply(self._init_weights)
        self.set_layer_lr_decay(layer_lr_decay)

        # The neck is created after weight init and lr tagging, so its
        # parameters carry no `lr_scale` attribute.
        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dims[-1],
                256,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(256),
            nn.Conv2d(
                256,
                256,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(256),
        )

    def set_layer_lr_decay(self, layer_lr_decay):
        """Tag every existing parameter with ``lr_scale`` and ``param_name`` attributes.

        Block ``i`` (counted across all stages) gets
        ``layer_lr_decay ** (total_blocks - i - 1)``; a stage's downsample layer
        shares the scale of that stage's last block.
        """
        decay_rate = layer_lr_decay

        # layers -> blocks (depth)
        depth = sum(self.depths)
        lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]

        def _set_lr_scale(m, scale):
            for p in m.parameters():
                p.lr_scale = scale

        self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
        i = 0
        for layer in self.layers:
            for block in layer.blocks:
                block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
                i += 1
            if layer.downsample is not None:
                layer.downsample.apply(
                    lambda x: _set_lr_scale(x, lr_scales[i - 1]))
        assert i == depth
        for m in [self.norm_head, self.head]:
            m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))

        for k, p in self.named_parameters():
            p.param_name = k

        def _check_lr_scale(m):
            for p in m.parameters():
                assert hasattr(p, 'lr_scale'), p.param_name

        self.apply(_check_lr_scale)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'attention_biases'}

    def forward_features(self, x):
        """Encode images into the neck's feature map.

        Args:
            x: Image batch of shape (N, C, H, W).

        Returns:
            Neck output in (N, 256, H', W') layout, where (H', W') is the last
            stage's ``input_resolution``.
        """
        x = self.patch_embed(x)
        x = self.layers[0](x)

        for i in range(1, len(self.layers)):
            x = self.layers[i](x)

        B, _, C = x.size()
        # Grid size comes from the stage configuration instead of a hard-coded
        # 64x64, so inputs other than 1024x1024 work too.
        H, W = self.layers[-1].input_resolution
        x = x.view(B, H, W, C)
        x = x.permute(0, 3, 1, 2)
        x = self.neck(x)
        return x

    def forward(self, x):
        """Return ``(features, None)``.

        The second element is a placeholder kept for compatibility with
        SAM-HQ-style encoders that also return intermediate embeddings.
        """
        x = self.forward_features(x)
        return x, None
619
+
620
+
621
+ _checkpoint_url_format = \
622
+ 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/{}.pth'
623
+ _provided_checkpoints = {
624
+ 'tiny_vit_5m_224': 'tiny_vit_5m_22kto1k_distill',
625
+ 'tiny_vit_11m_224': 'tiny_vit_11m_22kto1k_distill',
626
+ 'tiny_vit_21m_224': 'tiny_vit_21m_22kto1k_distill',
627
+ 'tiny_vit_21m_384': 'tiny_vit_21m_22kto1k_384_distill',
628
+ 'tiny_vit_21m_512': 'tiny_vit_21m_22kto1k_512_distill',
629
+ }
630
+
631
+
632
def register_tiny_vit_model(fn):
    """Register a TinyViT factory with `register_model`, adding checkpoint loading.

    The registered wrapper keeps the factory's ``__name__``. When called with
    ``pretrained=True`` it downloads the checkpoint listed for that name in
    ``_provided_checkpoints`` and loads its ``'model'`` entry into the model.

    NOTE: the wrapper accepts ``**kwargs`` but always calls ``fn()`` without
    arguments, so keyword arguments given to the registered model are ignored.
    """
    def fn_wrapper(pretrained=False, **kwargs):
        model = fn()
        if not pretrained:
            return model

        model_name = fn.__name__
        assert model_name in _provided_checkpoints, \
            f'Sorry that the checkpoint `{model_name}` is not provided yet.'
        url = _checkpoint_url_format.format(_provided_checkpoints[model_name])
        state = torch.hub.load_state_dict_from_url(
            url=url,
            map_location='cpu',
            check_hash=False,
        )
        model.load_state_dict(state['model'])
        return model

    # Keep the factory's name so the registry entry matches it.
    fn_wrapper.__name__ = fn.__name__
    return register_model(fn_wrapper)
655
+
656
+
657
@register_tiny_vit_model
def tiny_vit_5m_224(pretrained=False, num_classes=1000, drop_path_rate=0.0):
    """Build the `tiny_vit_5m_224` TinyViT configuration.

    ``pretrained`` is not read here; checkpoint loading is done by the
    registration wrapper.
    """
    config = dict(
        num_classes=num_classes,
        embed_dims=[64, 128, 160, 320],
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 5, 10],
        window_sizes=[7, 7, 14, 7],
        drop_path_rate=drop_path_rate,
    )
    return TinyViT(**config)
667
+
668
+
669
@register_tiny_vit_model
def tiny_vit_11m_224(pretrained=False, num_classes=1000, drop_path_rate=0.1):
    """Build the `tiny_vit_11m_224` TinyViT configuration.

    ``pretrained`` is not read here; checkpoint loading is done by the
    registration wrapper.
    """
    config = dict(
        num_classes=num_classes,
        embed_dims=[64, 128, 256, 448],
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 8, 14],
        window_sizes=[7, 7, 14, 7],
        drop_path_rate=drop_path_rate,
    )
    return TinyViT(**config)
679
+
680
+
681
@register_tiny_vit_model
def tiny_vit_21m_224(pretrained=False, num_classes=1000, drop_path_rate=0.2):
    """Build the `tiny_vit_21m_224` TinyViT configuration.

    ``pretrained`` is not read here; checkpoint loading is done by the
    registration wrapper.
    """
    config = dict(
        num_classes=num_classes,
        embed_dims=[96, 192, 384, 576],
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 18],
        window_sizes=[7, 7, 14, 7],
        drop_path_rate=drop_path_rate,
    )
    return TinyViT(**config)
691
+
692
+
693
@register_tiny_vit_model
def tiny_vit_21m_384(pretrained=False, num_classes=1000, drop_path_rate=0.1):
    """Build the `tiny_vit_21m_384` TinyViT configuration (img_size=384).

    ``pretrained`` is not read here; checkpoint loading is done by the
    registration wrapper.
    """
    config = dict(
        img_size=384,
        num_classes=num_classes,
        embed_dims=[96, 192, 384, 576],
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 18],
        window_sizes=[12, 12, 24, 12],
        drop_path_rate=drop_path_rate,
    )
    return TinyViT(**config)
704
+
705
+
706
@register_tiny_vit_model
def tiny_vit_21m_512(pretrained=False, num_classes=1000, drop_path_rate=0.1):
    """Build the `tiny_vit_21m_512` TinyViT configuration (img_size=512).

    ``pretrained`` is not read here; checkpoint loading is done by the
    registration wrapper.
    """
    config = dict(
        img_size=512,
        num_classes=num_classes,
        embed_dims=[96, 192, 384, 576],
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 18],
        window_sizes=[16, 16, 32, 16],
        drop_path_rate=drop_path_rate,
    )
    return TinyViT(**config)
EfficientSAM/README.md ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Efficient Grounded-SAM
2
+
3
+ We're going to combine [Grounding-DINO](https://github.com/IDEA-Research/GroundingDINO) with efficient SAM variants for faster annotating.
4
+
5
+ <!-- Combining [Grounding-DINO](https://github.com/IDEA-Research/GroundingDINO) and [Fast-SAM](https://github.com/CASIA-IVA-Lab/FastSAM) for faster zero-shot detect and segment anything. -->
6
+
7
+
8
+ ### Table of Contents
9
+ - [Installation](#installation)
10
+ - [Efficient SAM Series](#efficient-sams)
11
+ - [Run Grounded-FastSAM Demo](#run-grounded-fastsam-demo)
12
+ - [Run Grounded-MobileSAM Demo](#run-grounded-mobilesam-demo)
13
+ - [Run Grounded-LightHQSAM Demo](#run-grounded-light-hqsam-demo)
14
+ - [Run Grounded-Efficient-SAM Demo](#run-grounded-efficient-sam-demo)
15
+ - [Run Grounded-Edge-SAM Demo](#run-grounded-edge-sam-demo)
16
+ - [Run Grounded-RepViT-SAM Demo](#run-grounded-repvit-sam-demo)
17
+
18
+
19
+ ### Installation
20
+
21
+ - Install [Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything#installation)
22
+
23
+ - Install [Fast-SAM](https://github.com/CASIA-IVA-Lab/FastSAM#installation)
24
+
25
+ - Note that we may use the same image as the demo image in order to compare the inference results of different efficient-sam variants.
26
+
27
+ ### Efficient SAMs
28
+ Here's the list of Efficient SAM variants:
29
+
30
+ <div align="center">
31
+
32
+ | Title | Intro | Description | Links |
33
+ |:----:|:----:|:----:|:----:|
34
+ | [FastSAM](https://arxiv.org/pdf/2306.12156.pdf) | ![](https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/assets/Overview.png) | The Fast Segment Anything Model (FastSAM) is a CNN Segment Anything Model trained on only 2% of the SA-1B dataset published by the SAM authors. FastSAM achieves performance comparable to the SAM method at 50× higher run-time speed. | [[Github](https://github.com/CASIA-IVA-Lab/FastSAM)] [[Demo](https://huggingface.co/spaces/An-619/FastSAM)] |
35
+ | [MobileSAM](https://arxiv.org/pdf/2306.14289.pdf) | ![](https://github.com/ChaoningZhang/MobileSAM/blob/master/assets/model_diagram.jpg?raw=true) | MobileSAM performs on par with the original SAM (at least visually) and keeps exactly the same pipeline as the original SAM except for a change on the image encoder. Specifically, we replace the original heavyweight ViT-H encoder (632M) with a much smaller Tiny-ViT (5M). On a single GPU, MobileSAM runs around 12ms per image: 8ms on the image encoder and 4ms on the mask decoder. | [[Github](https://github.com/ChaoningZhang/MobileSAM)] |
36
+ | [Light-HQSAM](https://arxiv.org/pdf/2306.01567.pdf) | ![](https://github.com/SysCV/sam-hq/blob/main/figs/sam-hf-framework.png?raw=true) | Light HQ-SAM is based on the tiny vit image encoder provided by MobileSAM. We design a learnable High-Quality Output Token, which is injected into SAM's mask decoder and is responsible for predicting the high-quality mask. Instead of only applying it on mask-decoder features, we first fuse them with ViT features for improved mask details. Refer to [Light HQ-SAM vs. MobileSAM](https://github.com/SysCV/sam-hq#light-hq-sam-vs-mobilesam-on-coco) for more details. | [[Github](https://github.com/SysCV/sam-hq)] |
37
+ | [Efficient-SAM](https://github.com/yformer/EfficientSAM) | ![](https://yformer.github.io/efficient-sam/EfficientSAM_files/overview.png) |Segment Anything Model (SAM) has emerged as a powerful tool for numerous vision applications. However, the huge computation cost of SAM model has limited its applications to wider real-world applications. To address this limitation, we propose EfficientSAMs, light-weight SAM models that exhibit decent performance with largely reduced complexity. Our idea is based on leveraging masked image pretraining, SAMI, which learns to reconstruct features from SAM image encoder for effective visual representation learning. Further, we take SAMI-pretrained light-weight image encoders and mask decoder to build EfficientSAMs, and finetune the models on SA-1B for segment anything task. Refer to [EfficientSAM arXiv](https://arxiv.org/pdf/2312.00863.pdf) for more details.| [[Github](https://github.com/yformer/EfficientSAM)] |
38
+ | [Edge-SAM](https://github.com/chongzhou96/EdgeSAM) | ![](https://www.mmlab-ntu.com/project/edgesam/img/arch.png) | EdgeSAM involves distilling the original ViT-based SAM image encoder into a purely CNN-based architecture, better suited for edge devices. We carefully benchmark various distillation strategies and demonstrate that task-agnostic encoder distillation fails to capture the full knowledge embodied in SAM. Refer to [Edge-SAM arXiv](https://arxiv.org/abs/2312.06660) for more details. | [[Github](https://github.com/chongzhou96/EdgeSAM)] |
39
+ | [RepViT-SAM](https://github.com/THU-MIG/RepViT/tree/main/sam) | ![](https://jameslahm.github.io/repvit-sam/static/images/edge.png) | Recently, RepViT achieves the state-of-the-art performance and latency trade-off on mobile devices by incorporating efficient architectural designs of ViTs into CNNs. Here, to achieve real-time segmenting anything on mobile devices, following MobileSAM, we replace the heavyweight image encoder in SAM with RepViT model, ending up with the RepViT-SAM model. Extensive experiments show that RepViT-SAM can enjoy significantly better zero-shot transfer capability than MobileSAM, along with nearly 10× faster inference speed. Refer to [RepViT-SAM arXiv](https://arxiv.org/pdf/2312.05760.pdf) for more details. | [[Github](https://github.com/THU-MIG/RepViT)] |
40
+
41
+ </div>
42
+
43
+
44
+ ### Run Grounded-FastSAM Demo
45
+
46
+ - Firstly, download the pretrained Fast-SAM weight [here](https://github.com/CASIA-IVA-Lab/FastSAM#model-checkpoints)
47
+
48
+ - Run the demo with the following script:
49
+
50
+ ```bash
51
+ cd Grounded-Segment-Anything
52
+
53
+ python EfficientSAM/grounded_fast_sam.py --model_path "./FastSAM-x.pt" --img_path "assets/demo4.jpg" --text "the black dog." --output "./output/"
54
+ ```
55
+
56
+ - And the results will be saved in `./output/` as:
57
+
58
+ <div style="text-align: center">
59
+
60
+ | Input | Text | Output |
61
+ |:---:|:---:|:---:|
62
+ |![](/assets/demo4.jpg) | "The black dog." | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/fast_sam/demo4_0_caption_the%20black%20dog.jpg?raw=true) |
63
+
64
+ </div>
65
+
66
+
67
+ **Note**: Due to FastSAM's post-processing, only one box can be annotated at a time. If there are multiple box prompts, we currently save multiple annotated images to `./output`; this will be changed in a future release.
68
+
69
+
70
+ ### Run Grounded-MobileSAM Demo
71
+
72
+ - Firstly, download the pretrained MobileSAM weight [here](https://github.com/ChaoningZhang/MobileSAM/tree/master/weights)
73
+
74
+ - Run the demo with the following script:
75
+
76
+ ```bash
77
+ cd Grounded-Segment-Anything
78
+
79
+ python EfficientSAM/grounded_mobile_sam.py --MOBILE_SAM_CHECKPOINT_PATH "./EfficientSAM/mobile_sam.pt" --SOURCE_IMAGE_PATH "./assets/demo2.jpg" --CAPTION "the running dog"
80
+ ```
81
+
82
+ - And the result will be saved as `./gronded_mobile_sam_anontated_image.jpg` as:
83
+
84
+ <div style="text-align: center">
85
+
86
+ | Input | Text | Output |
87
+ |:---:|:---:|:---:|
88
+ |![](/assets/demo2.jpg) | "the running dog" | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/mobile_sam/grounded_mobile_sam_annotated_image.jpg?raw=true) |
89
+
90
+ </div>
91
+
92
+
93
+ ### Run Grounded-Light-HQSAM Demo
94
+
95
+ - Firstly, download the pretrained Light-HQSAM weight [here](https://github.com/SysCV/sam-hq#model-checkpoints)
96
+
97
+ - Run the demo with the following script:
98
+
99
+ ```bash
100
+ cd Grounded-Segment-Anything
101
+
102
+ python EfficientSAM/grounded_light_hqsam.py
103
+ ```
104
+
105
+ - And the result will be saved as `./gronded_light_hqsam_anontated_image.jpg` as:
106
+
107
+ <div style="text-align: center">
108
+
109
+ | Input | Text | Output |
110
+ |:---:|:---:|:---:|
111
+ |![](/EfficientSAM/LightHQSAM/example_light_hqsam.png) | "bench" | ![](/EfficientSAM/LightHQSAM/grounded_light_hqsam_annotated_image.jpg) |
112
+
113
+ </div>
114
+
115
+
116
+ ### Run Grounded-Efficient-SAM Demo
117
+
118
+ - Download the pretrained EfficientSAM checkpoint from [here](https://github.com/yformer/EfficientSAM#model) and put it under `Grounded-Segment-Anything/EfficientSAM`
119
+
120
+ - Run the demo with the following script:
121
+
122
+ ```bash
123
+ cd Grounded-Segment-Anything
124
+
125
+ python EfficientSAM/grounded_efficient_sam.py
126
+ ```
127
+
128
+ - And the result will be saved as `./EfficientSAM/gronded_efficient_sam_anontated_image.jpg`:
129
+
130
+ <div style="text-align: center">
131
+
132
+ | Input | Text | Output |
133
+ |:---:|:---:|:---:|
134
+ |![](/EfficientSAM/LightHQSAM/example_light_hqsam.png) | "bench" | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/efficient_sam/grounded_efficient_sam_annotated_image.jpg?raw=true) |
135
+
136
+ </div>
137
+
138
+
139
+ ### Run Grounded-Edge-SAM Demo
140
+
141
+ - Download the pretrained [Edge-SAM](https://github.com/chongzhou96/EdgeSAM) checkpoint following the [official instructions](https://github.com/chongzhou96/EdgeSAM?tab=readme-ov-file#usage-):
142
+
143
+ ```bash
144
+ cd Grounded-Segment-Anything
145
+ wget -P EfficientSAM/ https://huggingface.co/spaces/chongzhou/EdgeSAM/resolve/main/weights/edge_sam.pth
146
+ wget -P EfficientSAM/ https://huggingface.co/spaces/chongzhou/EdgeSAM/resolve/main/weights/edge_sam_3x.pth
147
+ ```
148
+
149
+ - Run the demo with the following script:
150
+
151
+ ```bash
152
+ cd Grounded-Segment-Anything
153
+
154
+ python EfficientSAM/grounded_edge_sam.py
155
+ ```
156
+
157
+ - And the result will be saved as `./EfficientSAM/grounded_edge_sam_annotated_image.jpg`:
158
+
159
+ <div style="text-align: center">
160
+
161
+ | Input | Text | Output |
162
+ |:---:|:---:|:---:|
163
+ |![](/EfficientSAM/LightHQSAM/example_light_hqsam.png) | "bench" | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/edge_sam/grounded_edge_sam_annotated_image.jpg?raw=true) |
164
+
165
+ </div>
166
+
167
+ ### Run Grounded-RepViT-SAM Demo
168
+
169
+ - Download the pretrained [RepViT-SAM](https://github.com/THU-MIG/RepViT) checkpoint following the [official instructions](https://github.com/THU-MIG/RepViT/tree/main/sam#installation):
170
+
171
+ ```bash
172
+ cd Grounded-Segment-Anything
173
+ wget -P EfficientSAM/ https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_sam.pt
174
+ ```
175
+
176
+ - Run the demo with the following script:
177
+
178
+ ```bash
179
+ cd Grounded-Segment-Anything
180
+
181
+ python EfficientSAM/grounded_repvit_sam.py
182
+ ```
183
+
184
+ - And the result will be saved as `./gronded_repvit_sam_anontated_image.jpg` as:
185
+
186
+ <div style="text-align: center">
187
+
188
+ | Input | Text | Output |
189
+ |:---:|:---:|:---:|
190
+ |![](/EfficientSAM/LightHQSAM/example_light_hqsam.png) | "bench" | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/repvit_sam/grounded_repvit_sam_annotated_image.jpg?raw=true) |
191
+
192
+ </div>
193
+
194
+
EfficientSAM/RepViTSAM/repvit.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
+ __all__ = ['repvit_m1']
5
+
6
+
7
+ def _make_divisible(v, divisor, min_value=None):
8
+ """
9
+ This function is taken from the original tf repo.
10
+ It ensures that all layers have a channel number that is divisible by 8
11
+ It can be seen here:
12
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
13
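+ :param v: value (e.g. a channel count) to round
14
+ :param divisor: step that v is rounded to
15
+ :param min_value: lower bound for the result; defaults to divisor when None
16
+ :return: v rounded to the nearest multiple of divisor (raised by one divisor if that would fall below 90% of v), but never below min_value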
17
+ """
18
+ if min_value is None:
19
+ min_value = divisor
20
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
21
+ # Make sure that round down does not go down by more than 10%.
22
+ if new_v < 0.9 * v:
23
+ new_v += divisor
24
+ return new_v
25
+
26
+ from timm.models.layers import SqueezeExcite
27
+
28
+ import torch
29
+
30
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
31
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
32
+ class LayerNorm2d(nn.Module):
33
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
34
+ super().__init__()
35
+ self.weight = nn.Parameter(torch.ones(num_channels))
36
+ self.bias = nn.Parameter(torch.zeros(num_channels))
37
+ self.eps = eps
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+ u = x.mean(1, keepdim=True)
41
+ s = (x - u).pow(2).mean(1, keepdim=True)
42
+ x = (x - u) / torch.sqrt(s + self.eps)
43
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
44
+ return x
45
+
46
+ class Conv2d_BN(torch.nn.Sequential):
47
+ def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
48
+ groups=1, bn_weight_init=1, resolution=-10000):
49
+ super().__init__()
50
+ self.add_module('c', torch.nn.Conv2d(
51
+ a, b, ks, stride, pad, dilation, groups, bias=False))
52
+ self.add_module('bn', torch.nn.BatchNorm2d(b))
53
+ torch.nn.init.constant_(self.bn.weight, bn_weight_init)
54
+ torch.nn.init.constant_(self.bn.bias, 0)
55
+
56
+ @torch.no_grad()
57
+ def fuse(self):
58
+ c, bn = self._modules.values()
59
+ w = bn.weight / (bn.running_var + bn.eps)**0.5
60
+ w = c.weight * w[:, None, None, None]
61
+ b = bn.bias - bn.running_mean * bn.weight / \
62
+ (bn.running_var + bn.eps)**0.5
63
+ m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
64
+ 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups,
65
+ device=c.weight.device)
66
+ m.weight.data.copy_(w)
67
+ m.bias.data.copy_(b)
68
+ return m
69
+
70
+ class Residual(torch.nn.Module):
71
+ def __init__(self, m, drop=0.):
72
+ super().__init__()
73
+ self.m = m
74
+ self.drop = drop
75
+
76
+ def forward(self, x):
77
+ if self.training and self.drop > 0:
78
+ return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
79
+ device=x.device).ge_(self.drop).div(1 - self.drop).detach()
80
+ else:
81
+ return x + self.m(x)
82
+
83
+ @torch.no_grad()
84
+ def fuse(self):
85
+ if isinstance(self.m, Conv2d_BN):
86
+ m = self.m.fuse()
87
+ assert(m.groups == m.in_channels)
88
+ identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
89
+ identity = torch.nn.functional.pad(identity, [1,1,1,1])
90
+ m.weight += identity.to(m.weight.device)
91
+ return m
92
+ elif isinstance(self.m, torch.nn.Conv2d):
93
+ m = self.m
94
+ assert(m.groups != m.in_channels)
95
+ identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
96
+ identity = torch.nn.functional.pad(identity, [1,1,1,1])
97
+ m.weight += identity.to(m.weight.device)
98
+ return m
99
+ else:
100
+ return self
101
+
102
+
103
+ class RepVGGDW(torch.nn.Module):
104
+ def __init__(self, ed) -> None:
105
+ super().__init__()
106
+ self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
107
+ self.conv1 = torch.nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
108
+ self.dim = ed
109
+ self.bn = torch.nn.BatchNorm2d(ed)
110
+
111
+ def forward(self, x):
112
+ return self.bn((self.conv(x) + self.conv1(x)) + x)
113
+
114
+ @torch.no_grad()
115
+ def fuse(self):
116
+ conv = self.conv.fuse()
117
+ conv1 = self.conv1
118
+
119
+ conv_w = conv.weight
120
+ conv_b = conv.bias
121
+ conv1_w = conv1.weight
122
+ conv1_b = conv1.bias
123
+
124
+ conv1_w = torch.nn.functional.pad(conv1_w, [1,1,1,1])
125
+
126
+ identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1,1,1,1])
127
+
128
+ final_conv_w = conv_w + conv1_w + identity
129
+ final_conv_b = conv_b + conv1_b
130
+
131
+ conv.weight.data.copy_(final_conv_w)
132
+ conv.bias.data.copy_(final_conv_b)
133
+
134
+ bn = self.bn
135
+ w = bn.weight / (bn.running_var + bn.eps)**0.5
136
+ w = conv.weight * w[:, None, None, None]
137
+ b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
138
+ (bn.running_var + bn.eps)**0.5
139
+ conv.weight.data.copy_(w)
140
+ conv.bias.data.copy_(b)
141
+ return conv
142
+
143
+
144
+ class RepViTBlock(nn.Module):
145
+ def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
146
+ super(RepViTBlock, self).__init__()
147
+ assert stride in [1, 2]
148
+
149
+ self.identity = stride == 1 and inp == oup
150
+ assert(hidden_dim == 2 * inp)
151
+
152
+ if stride == 2:
153
+ self.token_mixer = nn.Sequential(
154
+ Conv2d_BN(inp, inp, kernel_size, stride if inp != 320 else 1, (kernel_size - 1) // 2, groups=inp),
155
+ SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
156
+ Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
157
+ )
158
+ self.channel_mixer = Residual(nn.Sequential(
159
+ # pw
160
+ Conv2d_BN(oup, 2 * oup, 1, 1, 0),
161
+ nn.GELU() if use_hs else nn.GELU(),
162
+ # pw-linear
163
+ Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
164
+ ))
165
+ else:
166
+ # assert(self.identity)
167
+ self.token_mixer = nn.Sequential(
168
+ RepVGGDW(inp),
169
+ SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
170
+ )
171
+ if self.identity:
172
+ self.channel_mixer = Residual(nn.Sequential(
173
+ # pw
174
+ Conv2d_BN(inp, hidden_dim, 1, 1, 0),
175
+ nn.GELU() if use_hs else nn.GELU(),
176
+ # pw-linear
177
+ Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
178
+ ))
179
+ else:
180
+ self.channel_mixer = nn.Sequential(
181
+ # pw
182
+ Conv2d_BN(inp, hidden_dim, 1, 1, 0),
183
+ nn.GELU() if use_hs else nn.GELU(),
184
+ # pw-linear
185
+ Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
186
+ )
187
+
188
+ def forward(self, x):
189
+ return self.channel_mixer(self.token_mixer(x))
190
+
191
+ from timm.models.vision_transformer import trunc_normal_
192
+ class BN_Linear(torch.nn.Sequential):
193
+ def __init__(self, a, b, bias=True, std=0.02):
194
+ super().__init__()
195
+ self.add_module('bn', torch.nn.BatchNorm1d(a))
196
+ self.add_module('l', torch.nn.Linear(a, b, bias=bias))
197
+ trunc_normal_(self.l.weight, std=std)
198
+ if bias:
199
+ torch.nn.init.constant_(self.l.bias, 0)
200
+
201
+ @torch.no_grad()
202
+ def fuse(self):
203
+ bn, l = self._modules.values()
204
+ w = bn.weight / (bn.running_var + bn.eps)**0.5
205
+ b = bn.bias - self.bn.running_mean * \
206
+ self.bn.weight / (bn.running_var + bn.eps)**0.5
207
+ w = l.weight * w[None, :]
208
+ if l.bias is None:
209
+ b = b @ self.l.weight.T
210
+ else:
211
+ b = (l.weight @ b[:, None]).view(-1) + self.l.bias
212
+ m = torch.nn.Linear(w.size(1), w.size(0), device=l.weight.device)
213
+ m.weight.data.copy_(w)
214
+ m.bias.data.copy_(b)
215
+ return m
216
+
217
+ class Classfier(nn.Module):
218
+ def __init__(self, dim, num_classes, distillation=True):
219
+ super().__init__()
220
+ self.classifier = BN_Linear(dim, num_classes) if num_classes > 0 else torch.nn.Identity()
221
+ self.distillation = distillation
222
+ if distillation:
223
+ self.classifier_dist = BN_Linear(dim, num_classes) if num_classes > 0 else torch.nn.Identity()
224
+
225
+ def forward(self, x):
226
+ if self.distillation:
227
+ x = self.classifier(x), self.classifier_dist(x)
228
+ if not self.training:
229
+ x = (x[0] + x[1]) / 2
230
+ else:
231
+ x = self.classifier(x)
232
+ return x
233
+
234
+ @torch.no_grad()
235
+ def fuse(self):
236
+ classifier = self.classifier.fuse()
237
+ if self.distillation:
238
+ classifier_dist = self.classifier_dist.fuse()
239
+ classifier.weight += classifier_dist.weight
240
+ classifier.bias += classifier_dist.bias
241
+ classifier.weight /= 2
242
+ classifier.bias /= 2
243
+ return classifier
244
+ else:
245
+ return classifier
246
+
247
+ class RepViT(nn.Module):
248
+ def __init__(self, cfgs, num_classes=1000, distillation=False, img_size=1024):
249
+ super(RepViT, self).__init__()
250
+ # setting of inverted residual blocks
251
+ self.cfgs = cfgs
252
+
253
+ self.img_size = img_size
254
+
255
+ # building first layer
256
+ input_channel = self.cfgs[0][2]
257
+ patch_embed = torch.nn.Sequential(Conv2d_BN(3, input_channel // 2, 3, 2, 1), torch.nn.GELU(),
258
+ Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1))
259
+ layers = [patch_embed]
260
+ # building inverted residual blocks
261
+ block = RepViTBlock
262
+ for k, t, c, use_se, use_hs, s in self.cfgs:
263
+ output_channel = _make_divisible(c, 8)
264
+ exp_size = _make_divisible(input_channel * t, 8)
265
+ layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
266
+ input_channel = output_channel
267
+ self.features = nn.ModuleList(layers)
268
+ # self.classifier = Classfier(output_channel, num_classes, distillation)
269
+
270
+ self.neck = nn.Sequential(
271
+ nn.Conv2d(
272
+ output_channel,
273
+ 256,
274
+ kernel_size=1,
275
+ bias=False,
276
+ ),
277
+ LayerNorm2d(256),
278
+ nn.Conv2d(
279
+ 256,
280
+ 256,
281
+ kernel_size=3,
282
+ padding=1,
283
+ bias=False,
284
+ ),
285
+ LayerNorm2d(256),
286
+ )
287
+
288
+ def forward(self, x):
289
+ # x = self.features(x)
290
+ for f in self.features:
291
+ x = f(x)
292
+ # x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
293
+ x = self.neck(x)
294
+ return x, None
295
+
296
+ from timm.models import register_model
297
+
298
+ @register_model
299
+ def repvit(pretrained=False, num_classes = 1000, distillation=False, **kwargs):
300
+ """
301
+ Constructs the RepViT backbone from the block configuration below
302
+ """
303
+ cfgs = [
304
+ # k, t, c, SE, HS, s
305
+ [3, 2, 80, 1, 0, 1],
306
+ [3, 2, 80, 0, 0, 1],
307
+ [3, 2, 80, 1, 0, 1],
308
+ [3, 2, 80, 0, 0, 1],
309
+ [3, 2, 80, 1, 0, 1],
310
+ [3, 2, 80, 0, 0, 1],
311
+ [3, 2, 80, 0, 0, 1],
312
+ [3, 2, 160, 0, 0, 2],
313
+ [3, 2, 160, 1, 0, 1],
314
+ [3, 2, 160, 0, 0, 1],
315
+ [3, 2, 160, 1, 0, 1],
316
+ [3, 2, 160, 0, 0, 1],
317
+ [3, 2, 160, 1, 0, 1],
318
+ [3, 2, 160, 0, 0, 1],
319
+ [3, 2, 160, 0, 0, 1],
320
+ [3, 2, 320, 0, 1, 2],
321
+ [3, 2, 320, 1, 1, 1],
322
+ [3, 2, 320, 0, 1, 1],
323
+ [3, 2, 320, 1, 1, 1],
324
+ [3, 2, 320, 0, 1, 1],
325
+ [3, 2, 320, 1, 1, 1],
326
+ [3, 2, 320, 0, 1, 1],
327
+ [3, 2, 320, 1, 1, 1],
328
+ [3, 2, 320, 0, 1, 1],
329
+ [3, 2, 320, 1, 1, 1],
330
+ [3, 2, 320, 0, 1, 1],
331
+ [3, 2, 320, 1, 1, 1],
332
+ [3, 2, 320, 0, 1, 1],
333
+ [3, 2, 320, 1, 1, 1],
334
+ [3, 2, 320, 0, 1, 1],
335
+ [3, 2, 320, 1, 1, 1],
336
+ [3, 2, 320, 0, 1, 1],
337
+ [3, 2, 320, 1, 1, 1],
338
+ [3, 2, 320, 0, 1, 1],
339
+ [3, 2, 320, 1, 1, 1],
340
+ [3, 2, 320, 0, 1, 1],
341
+ [3, 2, 320, 1, 1, 1],
342
+ [3, 2, 320, 0, 1, 1],
343
+ [3, 2, 320, 1, 1, 1],
344
+ [3, 2, 320, 0, 1, 1],
345
+ [3, 2, 320, 1, 1, 1],
346
+ [3, 2, 320, 0, 1, 1],
347
+ [3, 2, 320, 1, 1, 1],
348
+ [3, 2, 320, 0, 1, 1],
349
+ [3, 2, 320, 1, 1, 1],
350
+ [3, 2, 320, 0, 1, 1],
351
+ [3, 2, 320, 1, 1, 1],
352
+ [3, 2, 320, 0, 1, 1],
353
+ [3, 2, 320, 1, 1, 1],
354
+ [3, 2, 320, 0, 1, 1],
355
+ # [3, 2, 320, 1, 1, 1],
356
+ # [3, 2, 320, 0, 1, 1],
357
+ [3, 2, 320, 0, 1, 1],
358
+ [3, 2, 640, 0, 1, 2],
359
+ [3, 2, 640, 1, 1, 1],
360
+ [3, 2, 640, 0, 1, 1],
361
+ # [3, 2, 640, 1, 1, 1],
362
+ # [3, 2, 640, 0, 1, 1]
363
+ ]
364
+ return RepViT(cfgs, num_classes=num_classes, distillation=distillation)
EfficientSAM/RepViTSAM/setup_repvit_sam.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from functools import partial
9
+ from segment_anything.modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
10
+ from RepViTSAM import repvit
11
+ from timm.models import create_model
12
+
13
+ def build_sam_repvit(checkpoint=None):
14
+ prompt_embed_dim = 256
15
+ image_size = 1024
16
+ vit_patch_size = 16
17
+ image_embedding_size = image_size // vit_patch_size
18
+ repvit_sam = Sam(
19
+ image_encoder=create_model('repvit'),
20
+ prompt_encoder=PromptEncoder(
21
+ embed_dim=prompt_embed_dim,
22
+ image_embedding_size=(image_embedding_size, image_embedding_size),
23
+ input_image_size=(image_size, image_size),
24
+ mask_in_chans=16,
25
+ ),
26
+ mask_decoder=MaskDecoder(
27
+ num_multimask_outputs=3,
28
+ transformer=TwoWayTransformer(
29
+ depth=2,
30
+ embedding_dim=prompt_embed_dim,
31
+ mlp_dim=2048,
32
+ num_heads=8,
33
+ ),
34
+ transformer_dim=prompt_embed_dim,
35
+ iou_head_depth=3,
36
+ iou_head_hidden_dim=256,
37
+ ),
38
+ pixel_mean=[123.675, 116.28, 103.53],
39
+ pixel_std=[58.395, 57.12, 57.375],
40
+ )
41
+
42
+ repvit_sam.eval()
43
+ if checkpoint is not None:
44
+ with open(checkpoint, "rb") as f:
45
+ state_dict = torch.load(f)
46
+ repvit_sam.load_state_dict(state_dict)
47
+ return repvit_sam
48
+
49
+ from functools import partial
50
+
51
+ sam_model_registry = {
52
+ "repvit": partial(build_sam_repvit),
53
+ }
EfficientSAM/grounded_edge_sam.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import supervision as sv
4
+
5
+ import torch
6
+ import torchvision
7
+
8
+ from groundingdino.util.inference import Model
9
+ from segment_anything import SamPredictor
10
+ from EdgeSAM.setup_edge_sam import build_edge_sam
11
+
12
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+
14
+ # GroundingDINO config and checkpoint
15
+ GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
16
+ GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"
17
+
18
+ # Building GroundingDINO inference model
19
+ grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)
20
+
21
+ # Building EdgeSAM predictor
22
+ EdgeSAM_CHECKPOINT_PATH = "./EfficientSAM/edge_sam_3x.pth"
23
+ edge_sam = build_edge_sam(checkpoint=EdgeSAM_CHECKPOINT_PATH)
24
+ edge_sam.to(device=DEVICE)
25
+
26
+ sam_predictor = SamPredictor(edge_sam)
27
+
28
+
29
+ # Predict classes and hyper-param for GroundingDINO
30
+ SOURCE_IMAGE_PATH = "./EfficientSAM/LightHQSAM/example_light_hqsam.png"
31
+ CLASSES = ["bench"]
32
+ BOX_THRESHOLD = 0.25
33
+ TEXT_THRESHOLD = 0.25
34
+ NMS_THRESHOLD = 0.8
35
+
36
+
37
+ # load image
38
+ image = cv2.imread(SOURCE_IMAGE_PATH)
39
+
40
+ # detect objects
41
+ detections = grounding_dino_model.predict_with_classes(
42
+ image=image,
43
+ classes=CLASSES,
44
+ box_threshold=BOX_THRESHOLD,
45
+ text_threshold=TEXT_THRESHOLD
46
+ )
47
+
48
+ # annotate image with detections
49
+ box_annotator = sv.BoxAnnotator()
50
+ labels = [
51
+ f"{CLASSES[class_id]} {confidence:0.2f}"
52
+ for _, _, confidence, class_id, _, _
53
+ in detections]
54
+ annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)
55
+
56
+ # save the annotated grounding dino image
57
+ cv2.imwrite("EfficientSAM/LightHQSAM/groundingdino_annotated_image.jpg", annotated_frame)
58
+
59
+
60
+ # NMS post process
61
+ print(f"Before NMS: {len(detections.xyxy)} boxes")
62
+ nms_idx = torchvision.ops.nms(
63
+ torch.from_numpy(detections.xyxy),
64
+ torch.from_numpy(detections.confidence),
65
+ NMS_THRESHOLD
66
+ ).numpy().tolist()
67
+
68
+ detections.xyxy = detections.xyxy[nms_idx]
69
+ detections.confidence = detections.confidence[nms_idx]
70
+ detections.class_id = detections.class_id[nms_idx]
71
+
72
+ print(f"After NMS: {len(detections.xyxy)} boxes")
73
+
74
+ # Prompting SAM with detected boxes
75
+ def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
76
+ sam_predictor.set_image(image)
77
+ result_masks = []
78
+ for box in xyxy:
79
+ masks, scores, logits = sam_predictor.predict(
80
+ box=box,
81
+ multimask_output=False,
82
+ hq_token_only=True,
83
+ )
84
+ index = np.argmax(scores)
85
+ result_masks.append(masks[index])
86
+ return np.array(result_masks)
87
+
88
+
89
+ # convert detections to masks
90
+ detections.mask = segment(
91
+ sam_predictor=sam_predictor,
92
+ image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
93
+ xyxy=detections.xyxy
94
+ )
95
+
96
+ # annotate image with detections
97
+ box_annotator = sv.BoxAnnotator()
98
+ mask_annotator = sv.MaskAnnotator()
99
+ labels = [
100
+ f"{CLASSES[class_id]} {confidence:0.2f}"
101
+ for _, _, confidence, class_id, _, _
102
+ in detections]
103
+ annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
104
+ annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
105
+
106
+ # save the annotated grounded-sam image
107
+ cv2.imwrite("EfficientSAM/grounded_edge_sam_annotated_image.jpg", annotated_image)
EfficientSAM/grounded_efficient_sam.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import supervision as sv
4
+
5
+ import torch
6
+ import torchvision
7
+ from torchvision.transforms import ToTensor
8
+
9
+ from groundingdino.util.inference import Model
10
+
11
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+
13
+ # GroundingDINO config and checkpoint
14
+ GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
15
+ GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"
16
+
17
+ # Building GroundingDINO inference model
18
+ grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)
19
+
20
+ # Loading the TorchScript EfficientSAM model
21
+ EFFICIENT_SAM_CHECHPOINT_PATH = "./EfficientSAM/efficientsam_s_gpu.jit"
22
+ efficientsam = torch.jit.load(EFFICIENT_SAM_CHECHPOINT_PATH)
23
+
24
+
25
+ # Predict classes and hyper-param for GroundingDINO
26
+ SOURCE_IMAGE_PATH = "./EfficientSAM/LightHQSAM/example_light_hqsam.png"
27
+ CLASSES = ["bench"]
28
+ BOX_THRESHOLD = 0.25
29
+ TEXT_THRESHOLD = 0.25
30
+ NMS_THRESHOLD = 0.8
31
+
32
+
33
+ # load image
34
+ image = cv2.imread(SOURCE_IMAGE_PATH)
35
+
36
+ # detect objects
37
+ detections = grounding_dino_model.predict_with_classes(
38
+ image=image,
39
+ classes=CLASSES,
40
+ box_threshold=BOX_THRESHOLD,
41
+ text_threshold=TEXT_THRESHOLD
42
+ )
43
+
44
+ # annotate image with detections
45
+ box_annotator = sv.BoxAnnotator()
46
+ labels = [
47
+ f"{CLASSES[class_id]} {confidence:0.2f}"
48
+ for _, _, confidence, class_id, _, _
49
+ in detections]
50
+ annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)
51
+
52
+ # save the annotated grounding dino image
53
+ cv2.imwrite("EfficientSAM/LightHQSAM/groundingdino_annotated_image.jpg", annotated_frame)
54
+
55
+
56
+ # NMS post process
57
+ print(f"Before NMS: {len(detections.xyxy)} boxes")
58
+ nms_idx = torchvision.ops.nms(
59
+ torch.from_numpy(detections.xyxy),
60
+ torch.from_numpy(detections.confidence),
61
+ NMS_THRESHOLD
62
+ ).numpy().tolist()
63
+
64
+ detections.xyxy = detections.xyxy[nms_idx]
65
+ detections.confidence = detections.confidence[nms_idx]
66
+ detections.class_id = detections.class_id[nms_idx]
67
+
68
+ print(f"After NMS: {len(detections.xyxy)} boxes")
69
+
70
+
71
+ def efficient_sam_box_prompt_segment(image, pts_sampled, model):
72
+ bbox = torch.reshape(torch.tensor(pts_sampled), [1, 1, 2, 2])
73
+ bbox_labels = torch.reshape(torch.tensor([2, 3]), [1, 1, 2])
74
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
75
+ img_tensor = ToTensor()(image)
76
+
77
+ predicted_logits, predicted_iou = model(
78
+ img_tensor[None, ...].cuda(),
79
+ bbox.cuda(),
80
+ bbox_labels.cuda(),
81
+ )
82
+ predicted_logits = predicted_logits.cpu()
83
+ all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy()
84
+ predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
85
+
86
+ max_predicted_iou = -1
87
+ selected_mask_using_predicted_iou = None
88
+ for m in range(all_masks.shape[0]):
89
+ curr_predicted_iou = predicted_iou[m]
90
+ if (
91
+ curr_predicted_iou > max_predicted_iou
92
+ or selected_mask_using_predicted_iou is None
93
+ ):
94
+ max_predicted_iou = curr_predicted_iou
95
+ selected_mask_using_predicted_iou = all_masks[m]
96
+ return selected_mask_using_predicted_iou
97
+
98
+
99
+ # collect segment results from EfficientSAM
100
+ result_masks = []
101
+ for box in detections.xyxy:
102
+ mask = efficient_sam_box_prompt_segment(image, box, efficientsam)
103
+ result_masks.append(mask)
104
+
105
+ detections.mask = np.array(result_masks)
106
+
107
+ # annotate image with detections
108
+ box_annotator = sv.BoxAnnotator()
109
+ mask_annotator = sv.MaskAnnotator()
110
+ labels = [
111
+ f"{CLASSES[class_id]} {confidence:0.2f}"
112
+ for _, _, confidence, class_id, _, _
113
+ in detections]
114
+ annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
115
+ annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
116
+
117
+ # save the annotated grounded-sam image
118
+ cv2.imwrite("EfficientSAM/gronded_efficient_sam_anontated_image.jpg", annotated_image)
EfficientSAM/grounded_fast_sam.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import cv2
3
+ from ultralytics import YOLO
4
+ from FastSAM.tools import *
5
+ from groundingdino.util.inference import load_model, load_image, predict, annotate, Model
6
+ from torchvision.ops import box_convert
7
+ import ast
8
+
9
def _str2bool(value):
    """Convert a command-line token such as "true"/"False"/"1"/"no" to a bool."""
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of the Grounded-Fast-SAM demo.

    Boolean options use ``_str2bool`` instead of ``type=bool``: argparse's
    ``bool`` treats every non-empty string (including "False") as True, so
    the flags could never be switched off from the command line.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path", type=str, default="./FastSAM/FastSAM-x.pt", help="model"
    )
    parser.add_argument(
        "--img_path", type=str, default="./images/dogs.jpg", help="path to image file"
    )
    parser.add_argument(
        "--text", type=str, default="the black dog.", help="text prompt for GroundingDINO"
    )
    parser.add_argument("--imgsz", type=int, default=1024, help="image size")
    parser.add_argument(
        "--iou",
        type=float,
        default=0.9,
        help="iou threshold for filtering the annotations",
    )
    parser.add_argument(
        "--conf", type=float, default=0.4, help="object confidence threshold"
    )
    parser.add_argument(
        "--output", type=str, default="./output/", help="image save path"
    )
    parser.add_argument(
        "--randomcolor", type=_str2bool, default=True, help="mask random color"
    )
    parser.add_argument(
        "--point_prompt", type=str, default="[[0,0]]", help="[[x1,y1],[x2,y2]]"
    )
    parser.add_argument(
        "--point_label",
        type=str,
        default="[0]",
        help="[1,0] 0:background, 1:foreground",
    )
    parser.add_argument("--box_prompt", type=str, default="[0,0,0,0]", help="[x,y,w,h]")
    parser.add_argument(
        "--better_quality",
        type=_str2bool,
        default=False,
        help="better quality using morphologyEx",
    )
    # Default to the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument(
        "--device", type=str, default=device, help="cuda:[0,1,2,3,4] or cpu"
    )
    parser.add_argument(
        "--retina",
        type=_str2bool,
        default=True,
        help="draw high-resolution segmentation masks",
    )
    parser.add_argument(
        "--withContours", type=_str2bool, default=False, help="draw the edges of the masks"
    )
    return parser.parse_args()
66
+
67
+
68
def main(args):
    """Run the Grounded-Fast-SAM demo and write one annotated image per box.

    FastSAM (loaded through ``YOLO``) segments the whole image, GroundingDINO
    finds boxes for ``args.text``, and each box is used to pick a FastSAM mask
    that is rendered and saved under ``args.output``.

    Args:
        args: namespace produced by ``parse_args``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``args.img_path``.
    """
    # Image path and text prompt
    img_path = args.img_path
    text = args.text

    # Output directory and the image name without its extension.
    save_path = args.output
    os.makedirs(save_path, exist_ok=True)
    # splitext keeps dotted names intact ("a.b.jpg" -> "a.b", not "a").
    basename = os.path.splitext(os.path.basename(img_path))[0]

    # Build the Fast-SAM model and segment the whole image.
    fast_sam = YOLO(args.model_path)
    results = fast_sam(
        img_path,
        imgsz=args.imgsz,
        device=args.device,
        retina_masks=args.retina,
        iou=args.iou,
        conf=args.conf,
        max_det=100,
    )

    # Build the GroundingDINO model and detect boxes for the text prompt.
    groundingdino_config = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    groundingdino_ckpt_path = "./groundingdino_swint_ogc.pth"

    image_source, image = load_image(img_path)
    grounding_model = load_model(groundingdino_config, groundingdino_ckpt_path)

    boxes, logits, phrases = predict(
        model=grounding_model,
        image=image,
        caption=text,
        box_threshold=0.3,
        text_threshold=0.25,
        device=args.device,
    )

    # Grounded-Fast-SAM: scale the boxes to pixel coordinates of the image.
    ori_img = cv2.imread(img_path)
    if ori_img is None:
        raise FileNotFoundError(f"could not read image: {img_path}")
    ori_h, ori_w = ori_img.shape[0], ori_img.shape[1]

    boxes = boxes * torch.Tensor([ori_w, ori_h, ori_w, ori_h])
    print(f"Detected Boxes: {len(boxes)}")
    boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").cpu().numpy().tolist()

    fast_sam_masks = results[0].masks
    if fast_sam_masks is None:
        # FastSAM found nothing to segment, so there is no mask to match.
        print("FastSAM produced no masks; nothing to save.")
        return

    # Save each box separately because of FastSAM's post-processing.
    for box_idx, box in enumerate(boxes):
        mask, _ = box_prompt(
            fast_sam_masks.data,
            box,
            ori_h,
            ori_w,
        )
        annotations = np.array([mask])
        img_array = fast_process(
            annotations=annotations,
            args=args,
            mask_random_color=True,
            bbox=box,
        )
        # The phrase comes from model output; keep it from escaping save_path.
        phrase = phrases[box_idx].replace("/", "_").replace(os.sep, "_")
        out_name = basename + f"_{str(box_idx)}_caption_{phrase}.jpg"
        cv2.imwrite(os.path.join(save_path, out_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
137
+
138
+
139
+ if __name__ == "__main__":
140
+ args = parse_args()
141
+ main(args)
EfficientSAM/grounded_light_hqsam.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import supervision as sv
4
+
5
+ import torch
6
+ import torchvision
7
+
8
+ from groundingdino.util.inference import Model
9
+ from segment_anything import SamPredictor
10
+ from LightHQSAM.setup_light_hqsam import setup_model
11
+
12
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+
14
+ # GroundingDINO config and checkpoint
15
+ GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
16
+ GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"
17
+
18
+ # Building GroundingDINO inference model
19
+ grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)
20
+
21
+ # Building Light HQ-SAM predictor
22
+ HQSAM_CHECKPOINT_PATH = "./EfficientSAM/sam_hq_vit_tiny.pth"
23
+ checkpoint = torch.load(HQSAM_CHECKPOINT_PATH)
24
+ light_hqsam = setup_model()
25
+ light_hqsam.load_state_dict(checkpoint, strict=True)
26
+ light_hqsam.to(device=DEVICE)
27
+
28
+ sam_predictor = SamPredictor(light_hqsam)
29
+
30
+
31
+ # Predict classes and hyper-param for GroundingDINO
32
+ SOURCE_IMAGE_PATH = "./EfficientSAM/LightHQSAM/example_light_hqsam.png"
33
+ CLASSES = ["bench"]
34
+ BOX_THRESHOLD = 0.25
35
+ TEXT_THRESHOLD = 0.25
36
+ NMS_THRESHOLD = 0.8
37
+
38
+
39
+ # load image
40
+ image = cv2.imread(SOURCE_IMAGE_PATH)
41
+
42
+ # detect objects
43
+ detections = grounding_dino_model.predict_with_classes(
44
+ image=image,
45
+ classes=CLASSES,
46
+ box_threshold=BOX_THRESHOLD,
47
+ text_threshold=TEXT_THRESHOLD
48
+ )
49
+
50
+ # annotate image with detections
51
+ box_annotator = sv.BoxAnnotator()
52
+ labels = [
53
+ f"{CLASSES[class_id]} {confidence:0.2f}"
54
+ for _, _, confidence, class_id, _, _
55
+ in detections]
56
+ annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)
57
+
58
+ # save the annotated grounding dino image
59
+ cv2.imwrite("EfficientSAM/LightHQSAM/groundingdino_annotated_image.jpg", annotated_frame)
60
+
61
+
62
+ # NMS post process
63
+ print(f"Before NMS: {len(detections.xyxy)} boxes")
64
+ nms_idx = torchvision.ops.nms(
65
+ torch.from_numpy(detections.xyxy),
66
+ torch.from_numpy(detections.confidence),
67
+ NMS_THRESHOLD
68
+ ).numpy().tolist()
69
+
70
+ detections.xyxy = detections.xyxy[nms_idx]
71
+ detections.confidence = detections.confidence[nms_idx]
72
+ detections.class_id = detections.class_id[nms_idx]
73
+
74
+ print(f"After NMS: {len(detections.xyxy)} boxes")
75
+
76
+ # Prompting SAM with detected boxes
77
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    """Prompt the predictor with every box and stack the best mask of each.

    Args:
        sam_predictor: predictor whose ``set_image`` and ``predict`` are used.
        image: image handed to ``set_image`` once, before any box is processed.
        xyxy: iterable of boxes; each one is passed as the ``box`` prompt.

    Returns:
        Array holding, for each box, the mask with the highest score.
    """
    sam_predictor.set_image(image)

    def best_mask_for(prompt_box):
        # Single-mask, HQ-token-only prediction; the low-res logits are unused.
        candidate_masks, candidate_scores, _ = sam_predictor.predict(
            box=prompt_box,
            multimask_output=False,
            hq_token_only=True,
        )
        return candidate_masks[np.argmax(candidate_scores)]

    return np.array([best_mask_for(prompt_box) for prompt_box in xyxy])
89
+
90
+
91
+ # convert detections to masks
92
+ detections.mask = segment(
93
+ sam_predictor=sam_predictor,
94
+ image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
95
+ xyxy=detections.xyxy
96
+ )
97
+
98
+ # annotate image with detections
99
+ box_annotator = sv.BoxAnnotator()
100
+ mask_annotator = sv.MaskAnnotator()
101
+ labels = [
102
+ f"{CLASSES[class_id]} {confidence:0.2f}"
103
+ for _, _, confidence, class_id, _, _
104
+ in detections]
105
+ annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
106
+ annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
107
+
108
+ # save the annotated grounded-sam image
109
+ cv2.imwrite("EfficientSAM/LightHQSAM/grounded_light_hqsam_annotated_image.jpg", annotated_image)
EfficientSAM/grounded_mobile_sam.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import supervision as sv
4
+ import argparse
5
+ import torch
6
+ import torchvision
7
+
8
+ from groundingdino.util.inference import Model
9
+ from segment_anything import SamPredictor
10
+ from MobileSAM.setup_mobile_sam import setup_model
11
+
12
def parse_args():
    """Parse the command-line options of the Grounded-MobileSAM demo.

    Returns:
        argparse.Namespace with checkpoint/image paths, the caption, the
        three output file names, the thresholds and the device.
    """
    # Default to the GPU when one is available.
    default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # (flag, type, default, help) in the order the options are registered.
    option_table = (
        ("--MOBILE_SAM_CHECKPOINT_PATH", str, "./EfficientSAM/mobile_sam.pt", "model"),
        ("--SOURCE_IMAGE_PATH", str, "./assets/demo2.jpg", "path to image file"),
        ("--CAPTION", str, "The running dog", "text prompt for GroundingDINO"),
        ("--OUT_FILE_BOX", str, "groundingdino_annotated_image.jpg", "the output filename"),
        ("--OUT_FILE_SEG", str, "grounded_mobile_sam_annotated_image.jpg", "the output filename"),
        ("--OUT_FILE_BIN_MASK", str, "grounded_mobile_sam_bin_mask.jpg", "the output filename"),
        ("--BOX_THRESHOLD", float, 0.25, ""),
        ("--TEXT_THRESHOLD", float, 0.25, ""),
        ("--NMS_THRESHOLD", float, 0.8, ""),
        ("--DEVICE", str, default_device, "cuda:[0,1,2,3,4] or cpu"),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)
    return cli.parse_args()
41
+
42
def main(args):
    """Detect objects with GroundingDINO, segment them with MobileSAM, save results.

    Writes three files: the box-annotated image (``args.OUT_FILE_BOX``), the
    binary mask of the first detection (``args.OUT_FILE_BIN_MASK``, skipped
    when nothing is detected) and the mask+box annotated image
    (``args.OUT_FILE_SEG``).

    Args:
        args: namespace produced by ``parse_args``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``args.SOURCE_IMAGE_PATH``.
    """
    DEVICE = args.DEVICE

    # GroundingDINO config and checkpoint
    GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"

    # Building GroundingDINO inference model on the requested device, so that
    # --DEVICE (and the CPU fallback) applies to the detector as well.
    grounding_dino_model = Model(
        model_config_path=GROUNDING_DINO_CONFIG_PATH,
        model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
        device=str(DEVICE),
    )

    # Building MobileSAM predictor
    MOBILE_SAM_CHECKPOINT_PATH = args.MOBILE_SAM_CHECKPOINT_PATH
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    checkpoint = torch.load(MOBILE_SAM_CHECKPOINT_PATH, map_location=DEVICE)
    mobile_sam = setup_model()
    mobile_sam.load_state_dict(checkpoint, strict=True)
    mobile_sam.to(device=DEVICE)

    sam_predictor = SamPredictor(mobile_sam)

    # Predict classes and hyper-param for GroundingDINO
    SOURCE_IMAGE_PATH = args.SOURCE_IMAGE_PATH
    CLASSES = [args.CAPTION]
    BOX_THRESHOLD = args.BOX_THRESHOLD
    TEXT_THRESHOLD = args.TEXT_THRESHOLD
    NMS_THRESHOLD = args.NMS_THRESHOLD

    # load image (cv2.imread returns None instead of raising on failure)
    image = cv2.imread(SOURCE_IMAGE_PATH)
    if image is None:
        raise FileNotFoundError(f"could not read image: {SOURCE_IMAGE_PATH}")

    # detect objects
    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes=CLASSES,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD
    )

    # annotate image with detections
    box_annotator = sv.BoxAnnotator()
    labels = [
        f"{CLASSES[class_id]} {confidence:0.2f}"
        for _, _, confidence, class_id, _, _
        in detections]
    annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)

    # save the annotated grounding dino image
    cv2.imwrite(args.OUT_FILE_BOX, annotated_frame)

    # NMS post process
    print(f"Before NMS: {len(detections.xyxy)} boxes")
    nms_idx = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy),
        torch.from_numpy(detections.confidence),
        NMS_THRESHOLD
    ).numpy().tolist()

    detections.xyxy = detections.xyxy[nms_idx]
    detections.confidence = detections.confidence[nms_idx]
    detections.class_id = detections.class_id[nms_idx]

    print(f"After NMS: {len(detections.xyxy)} boxes")

    # Prompting SAM with detected boxes
    def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
        """Return, for every box, the candidate mask with the highest score."""
        sam_predictor.set_image(image)
        result_masks = []
        for box in xyxy:
            masks, scores, logits = sam_predictor.predict(
                box=box,
                multimask_output=True
            )
            index = np.argmax(scores)
            result_masks.append(masks[index])
        return np.array(result_masks)

    # convert detections to masks
    detections.mask = segment(
        sam_predictor=sam_predictor,
        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        xyxy=detections.xyxy
    )

    # Binary mask of the first detection; indexing [0] would raise when the
    # detector found nothing, so skip the file in that case.
    if len(detections.mask) > 0:
        binary_mask = detections.mask[0].astype(np.uint8)*255
        cv2.imwrite(args.OUT_FILE_BIN_MASK, binary_mask)
    else:
        print("No detections; binary mask not written.")

    # annotate image with detections
    box_annotator = sv.BoxAnnotator()
    mask_annotator = sv.MaskAnnotator()
    labels = [
        f"{CLASSES[class_id]} {confidence:0.2f}"
        for _, _, confidence, class_id, _, _
        in detections]
    annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
    # save the annotated grounded-sam image
    cv2.imwrite(args.OUT_FILE_SEG, annotated_image)
142
+
143
+ if __name__ == "__main__":
144
+ args = parse_args()
145
+ main(args)
EfficientSAM/grounded_repvit_sam.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import supervision as sv
4
+
5
+ import torch
6
+ import torchvision
7
+
8
+ from groundingdino.util.inference import Model
9
+ from segment_anything import SamPredictor
10
+ from RepViTSAM.setup_repvit_sam import build_sam_repvit
11
+
12
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+
14
+ # GroundingDINO config and checkpoint
15
+ GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
16
+ GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"
17
+
18
+ # Building GroundingDINO inference model
19
+ grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)
20
+
21
+ # Building RepViT-SAM predictor
22
+ RepViTSAM_CHECKPOINT_PATH = "./EfficientSAM/repvit_sam.pt"
23
+ repvit_sam = build_sam_repvit(checkpoint=RepViTSAM_CHECKPOINT_PATH)
24
+ repvit_sam.to(device=DEVICE)
25
+
26
+ sam_predictor = SamPredictor(repvit_sam)
27
+
28
+
29
+ # Predict classes and hyper-param for GroundingDINO
30
+ SOURCE_IMAGE_PATH = "./EfficientSAM/LightHQSAM/example_light_hqsam.png"
31
+ CLASSES = ["bench"]
32
+ BOX_THRESHOLD = 0.25
33
+ TEXT_THRESHOLD = 0.25
34
+ NMS_THRESHOLD = 0.8
35
+
36
+
37
+ # load image
38
+ image = cv2.imread(SOURCE_IMAGE_PATH)
39
+
40
+ # detect objects
41
+ detections = grounding_dino_model.predict_with_classes(
42
+ image=image,
43
+ classes=CLASSES,
44
+ box_threshold=BOX_THRESHOLD,
45
+ text_threshold=TEXT_THRESHOLD
46
+ )
47
+
48
+ # annotate image with detections
49
+ box_annotator = sv.BoxAnnotator()
50
+ labels = [
51
+ f"{CLASSES[class_id]} {confidence:0.2f}"
52
+ for _, _, confidence, class_id, _, _
53
+ in detections]
54
+ annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)
55
+
56
+ # save the annotated grounding dino image
57
+ cv2.imwrite("EfficientSAM/LightHQSAM/groundingdino_annotated_image.jpg", annotated_frame)
58
+
59
+
60
+ # NMS post process
61
+ print(f"Before NMS: {len(detections.xyxy)} boxes")
62
+ nms_idx = torchvision.ops.nms(
63
+ torch.from_numpy(detections.xyxy),
64
+ torch.from_numpy(detections.confidence),
65
+ NMS_THRESHOLD
66
+ ).numpy().tolist()
67
+
68
+ detections.xyxy = detections.xyxy[nms_idx]
69
+ detections.confidence = detections.confidence[nms_idx]
70
+ detections.class_id = detections.class_id[nms_idx]
71
+
72
+ print(f"After NMS: {len(detections.xyxy)} boxes")
73
+
74
+ # Prompting SAM with detected boxes
75
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    """Run the predictor once per box and collect the top-scoring mask of each.

    Args:
        sam_predictor: predictor whose ``set_image`` and ``predict`` are used.
        image: image registered with the predictor before the boxes are processed.
        xyxy: iterable of boxes; each one is passed as the ``box`` prompt.

    Returns:
        Array with one mask per box, chosen by the highest score.
    """
    sam_predictor.set_image(image)

    chosen_masks = []
    for prompt_box in xyxy:
        # Only masks and scores are needed; the third return value is ignored.
        prediction = sam_predictor.predict(
            box=prompt_box,
            multimask_output=False,
            hq_token_only=True,
        )
        box_masks, box_scores = prediction[0], prediction[1]
        top = np.argmax(box_scores)
        chosen_masks.append(box_masks[top])

    return np.array(chosen_masks)
87
+
88
+
89
+ # convert detections to masks
90
+ detections.mask = segment(
91
+ sam_predictor=sam_predictor,
92
+ image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
93
+ xyxy=detections.xyxy
94
+ )
95
+
96
+ # annotate image with detections
97
+ box_annotator = sv.BoxAnnotator()
98
+ mask_annotator = sv.MaskAnnotator()
99
+ labels = [
100
+ f"{CLASSES[class_id]} {confidence:0.2f}"
101
+ for _, _, confidence, class_id, _, _
102
+ in detections]
103
+ annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
104
+ annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
105
+
106
+ # save the annotated grounded-sam image
107
+ cv2.imwrite("EfficientSAM/grounded_repvit_sam_annotated_image.jpg", annotated_image)
GroundingDINO/.asset/COCO.png ADDED
GroundingDINO/.asset/GD_GLIGEN.png ADDED

Git LFS Details

  • SHA256: 6e36d497ace68412ecd6c064fff6d7481a685963ffc2ec047a8892411fb0ab8e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.23 MB
GroundingDINO/.asset/GD_SD.png ADDED

Git LFS Details

  • SHA256: 92c8a690a2de028d42c9b876c73dca53b7736134eb77cce5b3cbda9d1c4b62de
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
GroundingDINO/.asset/ODinW.png ADDED
GroundingDINO/.asset/arch.png ADDED
GroundingDINO/.asset/cats.png ADDED
GroundingDINO/.asset/hero_figure.png ADDED

Git LFS Details

  • SHA256: 24b18b31e9f150bae0ae01b09608d7bf7fc34f42c8e17d85eda55ea4a55b1e91
  • Pointer size: 132 Bytes
  • Size of remote file: 2.98 MB
GroundingDINO/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2020 - present, Facebook, Inc
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
GroundingDINO/README.md ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Grounding DINO
2
+
3
+ ---
4
+
5
+ [![arXiv](https://img.shields.io/badge/arXiv-2303.05499-b31b1b.svg)](https://arxiv.org/abs/2303.05499)
6
+ [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/wxWDt5UiwY8)
7
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb)
8
+ [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/cMa77r3YrDk)
9
+ [![HuggingFace space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo)
10
+
11
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-mscoco)](https://paperswithcode.com/sota/zero-shot-object-detection-on-mscoco?p=grounding-dino-marrying-dino-with-grounded) \
12
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-odinw)](https://paperswithcode.com/sota/zero-shot-object-detection-on-odinw?p=grounding-dino-marrying-dino-with-grounded) \
13
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=grounding-dino-marrying-dino-with-grounded) \
14
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=grounding-dino-marrying-dino-with-grounded)
15
+
16
+
17
+
18
+ Official PyTorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499), a stronger open-set object detector. Code is available now!
19
+
20
+
21
+ ## Highlight
22
+
23
+ - **Open-Set Detection.** Detect **everything** with language!
24
+ - **High Performance.** COCO zero-shot **52.5 AP** (training without COCO data!). COCO fine-tune **63.0 AP**.
25
+ - **Flexible.** Collaboration with Stable Diffusion for Image Editing.
26
+
27
+ ## News
28
+ [2023/03/28] A YouTube [video](https://youtu.be/cMa77r3YrDk) about Grounding DINO and basic object detection prompt engineering. [[SkalskiP](https://github.com/SkalskiP)] \
29
+ [2023/03/28] Add a [demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo) on Hugging Face Space! \
30
+ [2023/03/27] Support CPU-only mode. Now the model can run on machines without GPUs.\
31
+ [2023/03/25] A [demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) for Grounding DINO is available at Colab. [[SkalskiP](https://github.com/SkalskiP)] \
32
+ [2023/03/22] Code is available now!
33
+
34
+ <details open>
35
+ <summary><font size="4">
36
+ Description
37
+ </font></summary>
38
+ <img src=".asset/hero_figure.png" alt="hero_figure" width="100%">
39
+ </details>
40
+
41
+
42
+
43
+ ## TODO
44
+
45
+ - [x] Release inference code and demo.
46
+ - [x] Release checkpoints.
47
+ - [ ] Grounding DINO with Stable Diffusion and GLIGEN demos.
48
+ - [ ] Release training codes.
49
+
50
+ ## Install
51
+
52
+ If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set. It will be compiled under CPU-only mode if no CUDA is available.
53
+
54
+ ```bash
55
+ pip install -e .
56
+ ```
57
+
58
+ ## Demo
59
+
60
+ ```bash
61
+ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
62
+ -c /path/to/config \
63
+ -p /path/to/checkpoint \
64
+ -i .asset/cats.png \
65
+ -o "outputs/0" \
66
+ -t "cat ear." \
67
+ [--cpu-only] # open it for cpu mode
68
+ ```
69
+ See the `demo/inference_on_a_image.py` for more details.
70
+
71
+ **Web UI**
72
+
73
+ We also provide demo code to integrate Grounding DINO with a Gradio Web UI. See the file `demo/gradio_app.py` for more details.
74
+
75
+ ## Checkpoints
76
+
77
+ <!-- insert a table -->
78
+ <table>
79
+ <thead>
80
+ <tr style="text-align: right;">
81
+ <th></th>
82
+ <th>name</th>
83
+ <th>backbone</th>
84
+ <th>Data</th>
85
+ <th>box AP on COCO</th>
86
+ <th>Checkpoint</th>
87
+ <th>Config</th>
88
+ </tr>
89
+ </thead>
90
+ <tbody>
91
+ <tr>
92
+ <th>1</th>
93
+ <td>GroundingDINO-T</td>
94
+ <td>Swin-T</td>
95
+ <td>O365,GoldG,Cap4M</td>
96
+ <td>48.4 (zero-shot) / 57.2 (fine-tune)</td>
97
+ <td><a href="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth">Github link</a> | <a href="https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth">HF link</a></td>
98
+ <td><a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/groundingdino/config/GroundingDINO_SwinT_OGC.py">link</a></td>
99
+ </tr>
100
+ </tbody>
101
+ </table>
102
+
103
+ ## Results
104
+
105
+ <details open>
106
+ <summary><font size="4">
107
+ COCO Object Detection Results
108
+ </font></summary>
109
+ <img src=".asset/COCO.png" alt="COCO" width="100%">
110
+ </details>
111
+
112
+ <details open>
113
+ <summary><font size="4">
114
+ ODinW Object Detection Results
115
+ </font></summary>
116
+ <img src=".asset/ODinW.png" alt="ODinW" width="100%">
117
+ </details>
118
+
119
+ <details open>
120
+ <summary><font size="4">
121
+ Marrying Grounding DINO with <a href="https://github.com/Stability-AI/StableDiffusion">Stable Diffusion</a> for Image Editing
122
+ </font></summary>
123
+ <img src=".asset/GD_SD.png" alt="GD_SD" width="100%">
124
+ </details>
125
+
126
+ <details open>
127
+ <summary><font size="4">
128
+ Marrying Grounding DINO with <a href="https://github.com/gligen/GLIGEN">GLIGEN</a> for more Detailed Image Editing
129
+ </font></summary>
130
+ <img src=".asset/GD_GLIGEN.png" alt="GD_GLIGEN" width="100%">
131
+ </details>
132
+
133
+ ## Model
134
+
135
+ Includes: a text backbone, an image backbone, a feature enhancer, a language-guided query selection, and a cross-modality decoder.
136
+
137
+ ![arch](.asset/arch.png)
138
+
139
+
140
+ ## Acknowledgement
141
+
142
+ Our model is related to [DINO](https://github.com/IDEA-Research/DINO) and [GLIP](https://github.com/microsoft/GLIP). Thanks for their great work!
143
+
144
+ We also thank great previous work including DETR, Deformable DETR, SMCA, Conditional DETR, Anchor DETR, Dynamic DETR, DAB-DETR, DN-DETR, etc. More related work is available at [Awesome Detection Transformer](https://github.com/IDEACVR/awesome-detection-transformer). A new toolbox [detrex](https://github.com/IDEA-Research/detrex) is available as well.
145
+
146
+ Thanks to [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) and [GLIGEN](https://github.com/gligen/GLIGEN) for their awesome models.
147
+
148
+
149
+ ## Citation
150
+
151
+ If you find our work helpful for your research, please consider citing the following BibTeX entry.
152
+
153
+ ```bibtex
154
+ @inproceedings{ShilongLiu2023GroundingDM,
155
+ title={Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection},
156
+ author={Shilong Liu and Zhaoyang Zeng and Tianhe Ren and Feng Li and Hao Zhang and Jie Yang and Chunyuan Li and Jianwei Yang and Hang Su and Jun Zhu and Lei Zhang},
157
+ year={2023}
158
+ }
159
+ ```
160
+
161
+
162
+
163
+
GroundingDINO/demo/gradio_app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from functools import partial
3
+ import cv2
4
+ import requests
5
+ import os
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ import numpy as np
9
+ from pathlib import Path
10
+
11
+
12
+ import warnings
13
+
14
+ import torch
15
+
16
+ # prepare the environment
17
+ os.system("python setup.py build develop --user")
18
+ os.system("pip install packaging==21.3")
19
+ os.system("pip install gradio")
20
+
21
+
22
+ warnings.filterwarnings("ignore")
23
+
24
+ import gradio as gr
25
+
26
+ from groundingdino.models import build_model
27
+ from groundingdino.util.slconfig import SLConfig
28
+ from groundingdino.util.utils import clean_state_dict
29
+ from groundingdino.util.inference import annotate, load_image, predict
30
+ import groundingdino.datasets.transforms as T
31
+
32
+ from huggingface_hub import hf_hub_download
33
+
34
+
35
+
36
+ # Use this command for evaluate the GLIP-T model
37
+ config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
38
+ ckpt_repo_id = "ShilongLiu/GroundingDINO"
39
+ ckpt_filenmae = "groundingdino_swint_ogc.pth"
40
+
41
+
42
+ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
43
+ args = SLConfig.fromfile(model_config_path)
44
+ model = build_model(args)
45
+ args.device = device
46
+
47
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
48
+ checkpoint = torch.load(cache_file, map_location='cpu')
49
+ log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
50
+ print("Model loaded from {} \n => {}".format(cache_file, log))
51
+ _ = model.eval()
52
+ return model
53
+
54
+ def image_transform_grounding(init_image):
55
+ transform = T.Compose([
56
+ T.RandomResize([800], max_size=1333),
57
+ T.ToTensor(),
58
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
59
+ ])
60
+ image, _ = transform(init_image, None) # 3, h, w
61
+ return init_image, image
62
+
63
+ def image_transform_grounding_for_vis(init_image):
64
+ transform = T.Compose([
65
+ T.RandomResize([800], max_size=1333),
66
+ ])
67
+ image, _ = transform(init_image, None) # 3, h, w
68
+ return image
69
+
70
+ model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
71
+
72
+ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
73
+ init_image = input_image.convert("RGB")
74
+ original_size = init_image.size
75
+
76
+ _, image_tensor = image_transform_grounding(init_image)
77
+ image_pil: Image = image_transform_grounding_for_vis(init_image)
78
+
79
+ # run grounidng
80
+ boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
81
+ annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
82
+ image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
83
+
84
+
85
+ return image_with_box
86
+
87
+ if __name__ == "__main__":
88
+
89
+ parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
90
+ parser.add_argument("--debug", action="store_true", help="using debug mode")
91
+ parser.add_argument("--share", action="store_true", help="share the app")
92
+ args = parser.parse_args()
93
+
94
+ block = gr.Blocks().queue()
95
+ with block:
96
+ gr.Markdown("# [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)")
97
+ gr.Markdown("### Open-World Detection with Grounding DINO")
98
+
99
+ with gr.Row():
100
+ with gr.Column():
101
+ input_image = gr.Image(source='upload', type="pil")
102
+ grounding_caption = gr.Textbox(label="Detection Prompt")
103
+ run_button = gr.Button(label="Run")
104
+ with gr.Accordion("Advanced options", open=False):
105
+ box_threshold = gr.Slider(
106
+ label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
107
+ )
108
+ text_threshold = gr.Slider(
109
+ label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
110
+ )
111
+
112
+ with gr.Column():
113
+ gallery = gr.outputs.Image(
114
+ type="pil",
115
+ # label="grounding results"
116
+ ).style(full_width=True, full_height=True)
117
+ # gallery = gr.Gallery(label="Generated images", show_label=False).style(
118
+ # grid=[1], height="auto", container=True, full_width=True, full_height=True)
119
+
120
+ run_button.click(fn=run_grounding, inputs=[
121
+ input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
122
+
123
+
124
+ block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)
125
+
GroundingDINO/demo/inference_on_a_image.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image, ImageDraw, ImageFont
8
+
9
+ import groundingdino.datasets.transforms as T
10
+ from groundingdino.models import build_model
11
+ from groundingdino.util import box_ops
12
+ from groundingdino.util.slconfig import SLConfig
13
+ from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
14
+
15
+
16
+ def plot_boxes_to_image(image_pil, tgt):
17
+ H, W = tgt["size"]
18
+ boxes = tgt["boxes"]
19
+ labels = tgt["labels"]
20
+ assert len(boxes) == len(labels), "boxes and labels must have same length"
21
+
22
+ draw = ImageDraw.Draw(image_pil)
23
+ mask = Image.new("L", image_pil.size, 0)
24
+ mask_draw = ImageDraw.Draw(mask)
25
+
26
+ # draw boxes and masks
27
+ for box, label in zip(boxes, labels):
28
+ # from 0..1 to 0..W, 0..H
29
+ box = box * torch.Tensor([W, H, W, H])
30
+ # from xywh to xyxy
31
+ box[:2] -= box[2:] / 2
32
+ box[2:] += box[:2]
33
+ # random color
34
+ color = tuple(np.random.randint(0, 255, size=3).tolist())
35
+ # draw
36
+ x0, y0, x1, y1 = box
37
+ x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
38
+
39
+ draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
40
+ # draw.text((x0, y0), str(label), fill=color)
41
+
42
+ font = ImageFont.load_default()
43
+ if hasattr(font, "getbbox"):
44
+ bbox = draw.textbbox((x0, y0), str(label), font)
45
+ else:
46
+ w, h = draw.textsize(str(label), font)
47
+ bbox = (x0, y0, w + x0, y0 + h)
48
+ # bbox = draw.textbbox((x0, y0), str(label))
49
+ draw.rectangle(bbox, fill=color)
50
+ draw.text((x0, y0), str(label), fill="white")
51
+
52
+ mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
53
+
54
+ return image_pil, mask
55
+
56
+
57
+ def load_image(image_path):
58
+ # load image
59
+ image_pil = Image.open(image_path).convert("RGB") # load image
60
+
61
+ transform = T.Compose(
62
+ [
63
+ T.RandomResize([800], max_size=1333),
64
+ T.ToTensor(),
65
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
66
+ ]
67
+ )
68
+ image, _ = transform(image_pil, None) # 3, h, w
69
+ return image_pil, image
70
+
71
+
72
+ def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
73
+ args = SLConfig.fromfile(model_config_path)
74
+ args.device = "cuda" if not cpu_only else "cpu"
75
+ model = build_model(args)
76
+ checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
77
+ load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
78
+ print(load_res)
79
+ _ = model.eval()
80
+ return model
81
+
82
+
83
+ def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, cpu_only=False):
84
+ caption = caption.lower()
85
+ caption = caption.strip()
86
+ if not caption.endswith("."):
87
+ caption = caption + "."
88
+ device = "cuda" if not cpu_only else "cpu"
89
+ model = model.to(device)
90
+ image = image.to(device)
91
+ with torch.no_grad():
92
+ outputs = model(image[None], captions=[caption])
93
+ logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
94
+ boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
95
+ logits.shape[0]
96
+
97
+ # filter output
98
+ logits_filt = logits.clone()
99
+ boxes_filt = boxes.clone()
100
+ filt_mask = logits_filt.max(dim=1)[0] > box_threshold
101
+ logits_filt = logits_filt[filt_mask] # num_filt, 256
102
+ boxes_filt = boxes_filt[filt_mask] # num_filt, 4
103
+ logits_filt.shape[0]
104
+
105
+ # get phrase
106
+ tokenlizer = model.tokenizer
107
+ tokenized = tokenlizer(caption)
108
+ # build pred
109
+ pred_phrases = []
110
+ for logit, box in zip(logits_filt, boxes_filt):
111
+ pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
112
+ if with_logits:
113
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
114
+ else:
115
+ pred_phrases.append(pred_phrase)
116
+
117
+ return boxes_filt, pred_phrases
118
+
119
+
120
+ if __name__ == "__main__":
121
+
122
+ parser = argparse.ArgumentParser("Grounding DINO example", add_help=True)
123
+ parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file")
124
+ parser.add_argument(
125
+ "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
126
+ )
127
+ parser.add_argument("--image_path", "-i", type=str, required=True, help="path to image file")
128
+ parser.add_argument("--text_prompt", "-t", type=str, required=True, help="text prompt")
129
+ parser.add_argument(
130
+ "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
131
+ )
132
+
133
+ parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
134
+ parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
135
+
136
+ parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False")
137
+ args = parser.parse_args()
138
+
139
+ # cfg
140
+ config_file = args.config_file # change the path of the model config file
141
+ checkpoint_path = args.checkpoint_path # change the path of the model
142
+ image_path = args.image_path
143
+ text_prompt = args.text_prompt
144
+ output_dir = args.output_dir
145
+ box_threshold = args.box_threshold
146
+ text_threshold = args.text_threshold
147
+
148
+ # make dir
149
+ os.makedirs(output_dir, exist_ok=True)
150
+ # load image
151
+ image_pil, image = load_image(image_path)
152
+ # load model
153
+ model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only)
154
+
155
+ # visualize raw image
156
+ image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
157
+
158
+ # run model
159
+ boxes_filt, pred_phrases = get_grounding_output(
160
+ model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only
161
+ )
162
+
163
+ # visualize pred
164
+ size = image_pil.size
165
+ pred_dict = {
166
+ "boxes": boxes_filt,
167
+ "size": [size[1], size[0]], # H,W
168
+ "labels": pred_phrases,
169
+ }
170
+ # import ipdb; ipdb.set_trace()
171
+ image_with_box = plot_boxes_to_image(image_pil, pred_dict)[0]
172
+ image_with_box.save(os.path.join(output_dir, "pred.jpg"))
GroundingDINO/groundingdino/__init__.py ADDED
File without changes
GroundingDINO/groundingdino/config/GroundingDINO_SwinB.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size = 1
2
+ modelname = "groundingdino"
3
+ backbone = "swin_B_384_22k"
4
+ position_embedding = "sine"
5
+ pe_temperatureH = 20
6
+ pe_temperatureW = 20
7
+ return_interm_indices = [1, 2, 3]
8
+ backbone_freeze_keywords = None
9
+ enc_layers = 6
10
+ dec_layers = 6
11
+ pre_norm = False
12
+ dim_feedforward = 2048
13
+ hidden_dim = 256
14
+ dropout = 0.0
15
+ nheads = 8
16
+ num_queries = 900
17
+ query_dim = 4
18
+ num_patterns = 0
19
+ num_feature_levels = 4
20
+ enc_n_points = 4
21
+ dec_n_points = 4
22
+ two_stage_type = "standard"
23
+ two_stage_bbox_embed_share = False
24
+ two_stage_class_embed_share = False
25
+ transformer_activation = "relu"
26
+ dec_pred_bbox_embed_share = True
27
+ dn_box_noise_scale = 1.0
28
+ dn_label_noise_ratio = 0.5
29
+ dn_label_coef = 1.0
30
+ dn_bbox_coef = 1.0
31
+ embed_init_tgt = True
32
+ dn_labelbook_size = 2000
33
+ max_text_len = 256
34
+ text_encoder_type = "bert-base-uncased"
35
+ use_text_enhancer = True
36
+ use_fusion_layer = True
37
+ use_checkpoint = True
38
+ use_transformer_ckpt = True
39
+ use_text_cross_attention = True
40
+ text_dropout = 0.0
41
+ fusion_dropout = 0.0
42
+ fusion_droppath = 0.1
43
+ sub_sentence_present = True
GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size = 1
2
+ modelname = "groundingdino"
3
+ backbone = "swin_T_224_1k"
4
+ position_embedding = "sine"
5
+ pe_temperatureH = 20
6
+ pe_temperatureW = 20
7
+ return_interm_indices = [1, 2, 3]
8
+ backbone_freeze_keywords = None
9
+ enc_layers = 6
10
+ dec_layers = 6
11
+ pre_norm = False
12
+ dim_feedforward = 2048
13
+ hidden_dim = 256
14
+ dropout = 0.0
15
+ nheads = 8
16
+ num_queries = 900
17
+ query_dim = 4
18
+ num_patterns = 0
19
+ num_feature_levels = 4
20
+ enc_n_points = 4
21
+ dec_n_points = 4
22
+ two_stage_type = "standard"
23
+ two_stage_bbox_embed_share = False
24
+ two_stage_class_embed_share = False
25
+ transformer_activation = "relu"
26
+ dec_pred_bbox_embed_share = True
27
+ dn_box_noise_scale = 1.0
28
+ dn_label_noise_ratio = 0.5
29
+ dn_label_coef = 1.0
30
+ dn_bbox_coef = 1.0
31
+ embed_init_tgt = True
32
+ dn_labelbook_size = 2000
33
+ max_text_len = 256
34
+ text_encoder_type = "bert-base-uncased"
35
+ use_text_enhancer = True
36
+ use_fusion_layer = True
37
+ use_checkpoint = True
38
+ use_transformer_ckpt = True
39
+ use_text_cross_attention = True
40
+ text_dropout = 0.0
41
+ fusion_dropout = 0.0
42
+ fusion_droppath = 0.1
43
+ sub_sentence_present = True
GroundingDINO/groundingdino/datasets/__init__.py ADDED
File without changes
GroundingDINO/groundingdino/datasets/transforms.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Transforms and data augmentation for both image + bbox.
4
+ """
5
+ import os
6
+ import random
7
+
8
+ import PIL
9
+ import torch
10
+ import torchvision.transforms as T
11
+ import torchvision.transforms.functional as F
12
+
13
+ from groundingdino.util.box_ops import box_xyxy_to_cxcywh
14
+ from groundingdino.util.misc import interpolate
15
+
16
+
17
+ def crop(image, target, region):
18
+ cropped_image = F.crop(image, *region)
19
+
20
+ target = target.copy()
21
+ i, j, h, w = region
22
+
23
+ # should we do something wrt the original size?
24
+ target["size"] = torch.tensor([h, w])
25
+
26
+ fields = ["labels", "area", "iscrowd", "positive_map"]
27
+
28
+ if "boxes" in target:
29
+ boxes = target["boxes"]
30
+ max_size = torch.as_tensor([w, h], dtype=torch.float32)
31
+ cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
32
+ cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
33
+ cropped_boxes = cropped_boxes.clamp(min=0)
34
+ area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
35
+ target["boxes"] = cropped_boxes.reshape(-1, 4)
36
+ target["area"] = area
37
+ fields.append("boxes")
38
+
39
+ if "masks" in target:
40
+ # FIXME should we update the area here if there are no boxes?
41
+ target["masks"] = target["masks"][:, i : i + h, j : j + w]
42
+ fields.append("masks")
43
+
44
+ # remove elements for which the boxes or masks that have zero area
45
+ if "boxes" in target or "masks" in target:
46
+ # favor boxes selection when defining which elements to keep
47
+ # this is compatible with previous implementation
48
+ if "boxes" in target:
49
+ cropped_boxes = target["boxes"].reshape(-1, 2, 2)
50
+ keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
51
+ else:
52
+ keep = target["masks"].flatten(1).any(1)
53
+
54
+ for field in fields:
55
+ if field in target:
56
+ target[field] = target[field][keep]
57
+
58
+ if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
59
+ # for debug and visualization only.
60
+ if "strings_positive" in target:
61
+ target["strings_positive"] = [
62
+ _i for _i, _j in zip(target["strings_positive"], keep) if _j
63
+ ]
64
+
65
+ return cropped_image, target
66
+
67
+
68
+ def hflip(image, target):
69
+ flipped_image = F.hflip(image)
70
+
71
+ w, h = image.size
72
+
73
+ target = target.copy()
74
+ if "boxes" in target:
75
+ boxes = target["boxes"]
76
+ boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
77
+ [w, 0, w, 0]
78
+ )
79
+ target["boxes"] = boxes
80
+
81
+ if "masks" in target:
82
+ target["masks"] = target["masks"].flip(-1)
83
+
84
+ return flipped_image, target
85
+
86
+
87
+ def resize(image, target, size, max_size=None):
88
+ # size can be min_size (scalar) or (w, h) tuple
89
+
90
+ def get_size_with_aspect_ratio(image_size, size, max_size=None):
91
+ w, h = image_size
92
+ if max_size is not None:
93
+ min_original_size = float(min((w, h)))
94
+ max_original_size = float(max((w, h)))
95
+ if max_original_size / min_original_size * size > max_size:
96
+ size = int(round(max_size * min_original_size / max_original_size))
97
+
98
+ if (w <= h and w == size) or (h <= w and h == size):
99
+ return (h, w)
100
+
101
+ if w < h:
102
+ ow = size
103
+ oh = int(size * h / w)
104
+ else:
105
+ oh = size
106
+ ow = int(size * w / h)
107
+
108
+ return (oh, ow)
109
+
110
+ def get_size(image_size, size, max_size=None):
111
+ if isinstance(size, (list, tuple)):
112
+ return size[::-1]
113
+ else:
114
+ return get_size_with_aspect_ratio(image_size, size, max_size)
115
+
116
+ size = get_size(image.size, size, max_size)
117
+ rescaled_image = F.resize(image, size)
118
+
119
+ if target is None:
120
+ return rescaled_image, None
121
+
122
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
123
+ ratio_width, ratio_height = ratios
124
+
125
+ target = target.copy()
126
+ if "boxes" in target:
127
+ boxes = target["boxes"]
128
+ scaled_boxes = boxes * torch.as_tensor(
129
+ [ratio_width, ratio_height, ratio_width, ratio_height]
130
+ )
131
+ target["boxes"] = scaled_boxes
132
+
133
+ if "area" in target:
134
+ area = target["area"]
135
+ scaled_area = area * (ratio_width * ratio_height)
136
+ target["area"] = scaled_area
137
+
138
+ h, w = size
139
+ target["size"] = torch.tensor([h, w])
140
+
141
+ if "masks" in target:
142
+ target["masks"] = (
143
+ interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
144
+ )
145
+
146
+ return rescaled_image, target
147
+
148
+
149
+ def pad(image, target, padding):
150
+ # assumes that we only pad on the bottom right corners
151
+ padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
152
+ if target is None:
153
+ return padded_image, None
154
+ target = target.copy()
155
+ # should we do something wrt the original size?
156
+ target["size"] = torch.tensor(padded_image.size[::-1])
157
+ if "masks" in target:
158
+ target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1]))
159
+ return padded_image, target
160
+
161
+
162
+ class ResizeDebug(object):
163
+ def __init__(self, size):
164
+ self.size = size
165
+
166
+ def __call__(self, img, target):
167
+ return resize(img, target, self.size)
168
+
169
+
170
+ class RandomCrop(object):
171
+ def __init__(self, size):
172
+ self.size = size
173
+
174
+ def __call__(self, img, target):
175
+ region = T.RandomCrop.get_params(img, self.size)
176
+ return crop(img, target, region)
177
+
178
+
179
+ class RandomSizeCrop(object):
180
+ def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
181
+ # respect_boxes: True to keep all boxes
182
+ # False to tolerence box filter
183
+ self.min_size = min_size
184
+ self.max_size = max_size
185
+ self.respect_boxes = respect_boxes
186
+
187
+ def __call__(self, img: PIL.Image.Image, target: dict):
188
+ init_boxes = len(target["boxes"])
189
+ max_patience = 10
190
+ for i in range(max_patience):
191
+ w = random.randint(self.min_size, min(img.width, self.max_size))
192
+ h = random.randint(self.min_size, min(img.height, self.max_size))
193
+ region = T.RandomCrop.get_params(img, [h, w])
194
+ result_img, result_target = crop(img, target, region)
195
+ if (
196
+ not self.respect_boxes
197
+ or len(result_target["boxes"]) == init_boxes
198
+ or i == max_patience - 1
199
+ ):
200
+ return result_img, result_target
201
+ return result_img, result_target
202
+
203
+
204
+ class CenterCrop(object):
205
+ def __init__(self, size):
206
+ self.size = size
207
+
208
+ def __call__(self, img, target):
209
+ image_width, image_height = img.size
210
+ crop_height, crop_width = self.size
211
+ crop_top = int(round((image_height - crop_height) / 2.0))
212
+ crop_left = int(round((image_width - crop_width) / 2.0))
213
+ return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
214
+
215
+
216
+ class RandomHorizontalFlip(object):
217
+ def __init__(self, p=0.5):
218
+ self.p = p
219
+
220
+ def __call__(self, img, target):
221
+ if random.random() < self.p:
222
+ return hflip(img, target)
223
+ return img, target
224
+
225
+
226
+ class RandomResize(object):
227
+ def __init__(self, sizes, max_size=None):
228
+ assert isinstance(sizes, (list, tuple))
229
+ self.sizes = sizes
230
+ self.max_size = max_size
231
+
232
+ def __call__(self, img, target=None):
233
+ size = random.choice(self.sizes)
234
+ return resize(img, target, size, self.max_size)
235
+
236
+
237
+ class RandomPad(object):
238
+ def __init__(self, max_pad):
239
+ self.max_pad = max_pad
240
+
241
+ def __call__(self, img, target):
242
+ pad_x = random.randint(0, self.max_pad)
243
+ pad_y = random.randint(0, self.max_pad)
244
+ return pad(img, target, (pad_x, pad_y))
245
+
246
+
247
+ class RandomSelect(object):
248
+ """
249
+ Randomly selects between transforms1 and transforms2,
250
+ with probability p for transforms1 and (1 - p) for transforms2
251
+ """
252
+
253
+ def __init__(self, transforms1, transforms2, p=0.5):
254
+ self.transforms1 = transforms1
255
+ self.transforms2 = transforms2
256
+ self.p = p
257
+
258
+ def __call__(self, img, target):
259
+ if random.random() < self.p:
260
+ return self.transforms1(img, target)
261
+ return self.transforms2(img, target)
262
+
263
+
264
+ class ToTensor(object):
265
+ def __call__(self, img, target):
266
+ return F.to_tensor(img), target
267
+
268
+
269
+ class RandomErasing(object):
270
+ def __init__(self, *args, **kwargs):
271
+ self.eraser = T.RandomErasing(*args, **kwargs)
272
+
273
+ def __call__(self, img, target):
274
+ return self.eraser(img), target
275
+
276
+
277
+ class Normalize(object):
278
+ def __init__(self, mean, std):
279
+ self.mean = mean
280
+ self.std = std
281
+
282
+ def __call__(self, image, target=None):
283
+ image = F.normalize(image, mean=self.mean, std=self.std)
284
+ if target is None:
285
+ return image, None
286
+ target = target.copy()
287
+ h, w = image.shape[-2:]
288
+ if "boxes" in target:
289
+ boxes = target["boxes"]
290
+ boxes = box_xyxy_to_cxcywh(boxes)
291
+ boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
292
+ target["boxes"] = boxes
293
+ return image, target
294
+
295
+
296
+ class Compose(object):
297
+ def __init__(self, transforms):
298
+ self.transforms = transforms
299
+
300
+ def __call__(self, image, target):
301
+ for t in self.transforms:
302
+ image, target = t(image, target)
303
+ return image, target
304
+
305
+ def __repr__(self):
306
+ format_string = self.__class__.__name__ + "("
307
+ for t in self.transforms:
308
+ format_string += "\n"
309
+ format_string += " {0}".format(t)
310
+ format_string += "\n)"
311
+ return format_string
GroundingDINO/groundingdino/models/GroundingDINO/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Conditional DETR
8
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Copied from DETR (https://github.com/facebookresearch/detr)
12
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13
+ # ------------------------------------------------------------------------
14
+
15
+ from .groundingdino import build_groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/backbone/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .backbone import build_backbone
GroundingDINO/groundingdino/models/GroundingDINO/backbone/backbone.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Conditional DETR
8
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Copied from DETR (https://github.com/facebookresearch/detr)
12
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13
+ # ------------------------------------------------------------------------
14
+
15
+ """
16
+ Backbone modules.
17
+ """
18
+
19
+ from typing import Dict, List
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torchvision
24
+ from torch import nn
25
+ from torchvision.models._utils import IntermediateLayerGetter
26
+
27
+ from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process
28
+
29
+ from .position_encoding import build_position_encoding
30
+ from .swin_transformer import build_swin_transformer
31
+
32
+
33
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d whose batch statistics and affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with eps added before rsqrt;
    without it, models other than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        # Everything is registered as a buffer, so nothing here is a Parameter.
        initial_values = (
            ("weight", torch.ones),
            ("bias", torch.zeros),
            ("running_mean", torch.zeros),
            ("running_var", torch.ones),
        )
        for buffer_name, factory in initial_values:
            self.register_buffer(buffer_name, factory(n))

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # This module has no num_batches_tracked buffer; drop the key if the
        # checkpoint carries one before delegating to the default loader.
        state_dict.pop(prefix + "num_batches_tracked", None)

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x):
        # Reshape the per-channel vectors up front (fuser-friendly) so they
        # broadcast against a 4-D input along dimension 1.
        per_channel = (1, -1, 1, 1)
        w = self.weight.reshape(per_channel)
        b = self.bias.reshape(per_channel)
        rv = self.running_var.reshape(per_channel)
        rm = self.running_mean.reshape(per_channel)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        shift = b - rm * scale
        return x * scale + shift
71
+
72
+
73
class BackboneBase(nn.Module):
    """
    Wrap a backbone in an IntermediateLayerGetter and pair each returned
    feature map with a padding mask resized to that map's spatial size.
    """

    def __init__(
        self,
        backbone: nn.Module,
        train_backbone: bool,
        num_channels: int,
        return_interm_indices: list,
    ):
        super().__init__()
        # Only parameters whose name mentions layer2/3/4 stay trainable, and
        # only when train_backbone is set; everything else is frozen.
        trainable_stages = ("layer2", "layer3", "layer4")
        for name, parameter in backbone.named_parameters():
            in_trainable_stage = any(stage in name for stage in trainable_stages)
            if not (train_backbone and in_trainable_stage):
                parameter.requires_grad_(False)

        # The requested indices map onto the last len(indices) stages, e.g.
        # [1, 2, 3] -> {"layer2": "1", "layer3": "2", "layer4": "3"}.
        first_stage = 5 - len(return_interm_indices)
        return_layers = {
            "layer{}".format(first_stage + offset): "{}".format(layer_index)
            for offset, layer_index in enumerate(return_interm_indices)
        }

        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        features = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, feature in features.items():
            m = tensor_list.mask
            assert m is not None
            # Resize the mask to this feature map's resolution, then drop the
            # temporary leading dimension and return to bool.
            resized = F.interpolate(m[None].float(), size=feature.shape[-2:])
            mask = resized.to(torch.bool)[0]
            out[name] = NestedTensor(feature, mask)
        return out
117
+
118
+
119
class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm.

    Args:
        name: torchvision ResNet name. Only "resnet50" and "resnet101" are
            accepted; "resnet18"/"resnet34" fail the assertion below and any
            other name raises NotImplementedError.
        train_backbone: forwarded to BackboneBase.
        dilation: forwarded as the last entry of replace_stride_with_dilation.
        return_interm_indices: one of [0, 1, 2, 3], [1, 2, 3] or [3].
        batch_norm: norm layer class handed to the torchvision constructor.
    """

    def __init__(
        self,
        name: str,
        train_backbone: bool,
        dilation: bool,
        return_interm_indices: list,
        batch_norm=FrozenBatchNorm2d,
    ):
        # Validate the configuration before building the network, so an
        # unsupported setup fails without constructing the model or
        # attempting a pretrained-weight load first.
        if name not in ["resnet18", "resnet34", "resnet50", "resnet101"]:
            raise NotImplementedError("Why you can get here with name {}".format(name))
        assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
        assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]

        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            # NOTE(review): presumably so that only the main process fetches
            # pretrained weights - confirm against is_main_process().
            pretrained=is_main_process(),
            norm_layer=batch_norm,
        )
        # Channel widths of the four stages; keep the trailing entries that
        # correspond to the requested stages.
        num_channels_all = [256, 512, 1024, 2048]
        num_channels = num_channels_all[4 - len(return_interm_indices) :]
        super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
144
+
145
+
146
class Joiner(nn.Sequential):
    """Two-module container: index 0 is the backbone, index 1 the position embedding."""

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        """Return (features, positions): parallel lists, one entry per backbone output."""
        features = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for feature in features.values():
            out.append(feature)
            # Position encoding, cast to the dtype of the feature tensors.
            encoding = self[1](feature)
            pos.append(encoding.to(feature.tensors.dtype))

        return out, pos
160
+
161
+
162
def build_backbone(args):
    """
    Build the backbone + position-encoding Joiner described by ``args``.

    Args read from ``args``:
        - backbone: "resnet50", "resnet101", or one of the swin_* names below
        - dilation: ResNet only
        - return_interm_indices: one of [0,1,2,3], [1,2,3], [3]
        - use_checkpoint: optional, Swin only, defaults to False
        - whatever build_position_encoding(args) reads

    Returns:
        A Joiner whose ``num_channels`` attribute lists the channel count of
        each returned feature level.

    Raises:
        NotImplementedError: if ``args.backbone`` is not a supported name.
    """
    position_embedding = build_position_encoding(args)
    # The backbone is always built as trainable here; lr_backbone and
    # backbone_freeze_keywords are not consulted by this function.
    train_backbone = True
    return_interm_indices = args.return_interm_indices
    assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
    use_checkpoint = getattr(args, "use_checkpoint", False)

    if args.backbone in ["resnet50", "resnet101"]:
        backbone = Backbone(
            args.backbone,
            train_backbone,
            args.dilation,
            return_interm_indices,
            batch_norm=FrozenBatchNorm2d,
        )
        bb_num_channels = backbone.num_channels
    elif args.backbone in [
        "swin_T_224_1k",
        "swin_B_224_22k",
        "swin_B_384_22k",
        "swin_L_224_22k",
        "swin_L_384_22k",
    ]:
        # The pretraining image size is the second-to-last name component,
        # e.g. "swin_B_384_22k" -> 384.
        pretrain_img_size = int(args.backbone.split("_")[-2])
        backbone = build_swin_transformer(
            args.backbone,
            pretrain_img_size=pretrain_img_size,
            out_indices=tuple(return_interm_indices),
            dilation=False,
            use_checkpoint=use_checkpoint,
        )

        bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
    else:
        raise NotImplementedError("Unknown backbone {}".format(args.backbone))

    # Check the type first so the length comparison below is meaningful.
    assert isinstance(
        bb_num_channels, list
    ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
    assert len(bb_num_channels) == len(
        return_interm_indices
    ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"

    model = Joiner(backbone, position_embedding)
    model.num_channels = bb_num_channels
    return model
GroundingDINO/groundingdino/models/GroundingDINO/backbone/position_encoding.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # DINO
8
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Conditional DETR
12
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
13
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
14
+ # ------------------------------------------------------------------------
15
+ # Copied from DETR (https://github.com/facebookresearch/detr)
16
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
17
+ # ------------------------------------------------------------------------
18
+
19
+ """
20
+ Various positional encodings for the transformer.
21
+ """
22
+ import math
23
+
24
+ import torch
25
+ from torch import nn
26
+
27
+ from groundingdino.util.misc import NestedTensor
28
+
29
+
30
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.

    Args:
        num_pos_feats: number of encoding channels per axis; the output has
            2 * num_pos_feats channels (y channels first, then x).
        temperature: base of the geometric frequency progression.
        normalize: if True, cumulative positions are divided by the last
            row/column value and multiplied by ``scale``.
        scale: multiplier used when normalizing; defaults to 2 * pi.

    Raises:
        ValueError: if ``scale`` is given while ``normalize`` is False.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        """Return a (B, 2 * num_pos_feats, H, W) encoding for a (B, H, W) bool mask."""
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        # Positions are running counts of unmasked cells along each axis.
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        # Explicit floor division, matching PositionEmbeddingSineHW; `//` on a
        # float tensor triggers a deprecation warning on some torch releases.
        dim_t = self.temperature ** (
            2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats
        )

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin (even channels) and cos (odd channels).
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
76
+
77
+
78
class PositionEmbeddingSineHW(nn.Module):
    """
    Sine/cosine position embedding for images with separate temperatures for
    the height (temperatureH) and width (temperatureW) axes.
    """

    def __init__(
        self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None
    ):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperatureH = temperatureH
        self.temperatureW = temperatureW
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        # Positions are running counts of unmasked cells along each axis.
        valid = ~mask
        y_embed = valid.cumsum(1, dtype=torch.float32)
        x_embed = valid.cumsum(2, dtype=torch.float32)

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        def encode(embed, temperature):
            # Divide by the per-channel frequency, then interleave sin (even
            # channels) and cos (odd channels).
            index = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
            freq = temperature ** (2 * (torch.div(index, 2, rounding_mode='floor')) / self.num_pos_feats)
            angles = embed[:, :, :, None] / freq
            return torch.stack(
                (angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos()), dim=4
            ).flatten(3)

        pos_x = encode(x_embed, self.temperatureW)
        pos_y = encode(y_embed, self.temperatureH)
        # y channels come first, then x; move channels to dimension 1.
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
132
+
133
+
134
class PositionEmbeddingLearned(nn.Module):
    """
    Learned absolute position embedding built from one row table and one
    column table, each holding 50 entries of num_pos_feats values.
    """

    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        # Both tables are re-initialized with nn.init.uniform_ defaults.
        for table in (self.row_embed, self.col_embed):
            nn.init.uniform_(table.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        batch = x.shape[0]
        h, w = x.shape[-2:]
        col = self.col_embed(torch.arange(w, device=x.device))  # (w, num_pos_feats)
        row = self.row_embed(torch.arange(h, device=x.device))  # (h, num_pos_feats)
        # Tile both to (h, w, num_pos_feats); column features come first.
        grid = torch.cat(
            [
                col.unsqueeze(0).repeat(h, 1, 1),
                row.unsqueeze(1).repeat(1, w, 1),
            ],
            dim=-1,
        )
        # Channels first, then one copy per batch element.
        return grid.permute(2, 0, 1).unsqueeze(0).repeat(batch, 1, 1, 1)
169
+
170
+
171
def build_position_encoding(args):
    """
    Create the position-embedding module selected by ``args.position_embedding``.

    "v2"/"sine" builds a normalized PositionEmbeddingSineHW using
    args.pe_temperatureH / args.pe_temperatureW; "v3"/"learned" builds a
    PositionEmbeddingLearned. Both use args.hidden_dim // 2 features.
    Any other value raises ValueError.
    """
    N_steps = args.hidden_dim // 2
    kind = args.position_embedding
    if kind in ("v2", "sine"):
        # TODO find a better way of exposing other arguments
        return PositionEmbeddingSineHW(
            N_steps,
            temperatureH=args.pe_temperatureH,
            temperatureW=args.pe_temperatureW,
            normalize=True,
        )
    if kind in ("v3", "learned"):
        return PositionEmbeddingLearned(N_steps)
    raise ValueError(f"not supported {args.position_embedding}")
GroundingDINO/groundingdino/models/GroundingDINO/backbone/swin_transformer.py ADDED
@@ -0,0 +1,802 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # DINO
8
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # --------------------------------------------------------
11
+ # modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
12
+ # --------------------------------------------------------
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.utils.checkpoint as checkpoint
19
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
20
+
21
+ from groundingdino.util.misc import NestedTensor
22
+
23
+
24
class Mlp(nn.Module):
    """Two-layer perceptron: fc1 -> activation -> dropout -> fc2 -> dropout."""

    def __init__(
        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
    ):
        super().__init__()
        # Falsy widths (None or 0) fall back to the input width.
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # The same dropout module is applied after the activation and after fc2.
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
45
+
46
+
47
def window_partition(x, window_size):
    """
    Cut a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): side length of each window
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    rows, cols = H // window_size, W // window_size
    # Split H and W into (block, offset) pairs, then bring the two block
    # axes together so every window is contiguous.
    blocks = x.view(B, rows, window_size, cols, window_size, C)
    blocks = blocks.permute(0, 1, 3, 2, 4, 5).contiguous()
    return blocks.view(-1, window_size, window_size, C)
59
+
60
+
61
def window_reverse(windows, window_size, H, W):
    """
    Stitch windows back into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): side length of each window
        H (int): height of the reassembled map
        W (int): width of the reassembled map
    Returns:
        x: (B, H, W, C)
    """
    # Batch size = total windows / windows per image.
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    rows, cols = H // window_size, W // window_size
    grid = windows.view(B, rows, cols, window_size, window_size, -1)
    # Interleave block and in-window axes back into plain H and W.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(B, H, W, -1)
75
+
76
+
77
class WindowAttention(nn.Module):
    """Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    A learnable bias, looked up by the relative offset between each pair of
    tokens in a window, is added to the attention logits before the softmax.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # `or` means any falsy qk_scale (None or 0) selects the default scale.
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # Scale the row offset by the number of distinct column offsets so
        # that the sum below is a unique flat index into the bias table.
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # Registered as a buffer: stored with the module, but not a Parameter.
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        Returns:
            Tensor of shape (num_windows*B, N, C).
        """
        B_, N, C = x.shape
        # One projection yields q, k and v; after the permute the leading
        # axis selects which of the three, followed by (B_, heads, N, head_dim).
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        # Gather one bias per (token, token, head) from the table.
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1
        ).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            # Split the window axis off the batch so the per-window mask
            # broadcasts over batch and heads, then fold it back.
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # Merge the heads back into the channel dimension.
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
175
+
176
+
177
class SwinTransformerBlock(nn.Module):
    """Swin Transformer Block.

    Windowed self-attention (cyclically shifted when shift_size > 0) followed
    by an MLP, each wrapped in a residual connection with drop-path.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
        )

        # Spatial size of the incoming feature map. It must be assigned by
        # the caller before forward() runs; forward() reads it from here.
        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C), where H and W are the
                values currently stored in self.H and self.W.
            mask_matrix: Attention mask for cyclic shift; only used when
                shift_size > 0.
        Returns:
            Tensor of size (B, H*W, C).
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        # F.pad lists pads from the last dimension backwards: C, then W, then H.
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(
            shifted_x, self.window_size
        )  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(
            -1, self.window_size * self.window_size, C
        )  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        # Crop away the padding added above.
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        # First residual: attention branch; second residual: MLP branch.
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
299
+
300
+
301
class PatchMerging(nn.Module):
    """Patch Merging Layer: concatenate each 2x2 neighbourhood (4*dim channels),
    normalize, and project linearly to 2*dim channels.
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        grid = x.view(B, H, W, C)

        # Pad odd heights/widths by one so every 2x2 neighbourhood is complete.
        extra_h, extra_w = H % 2, W % 2
        if extra_h == 1 or extra_w == 1:
            grid = F.pad(grid, (0, 0, 0, extra_w, 0, extra_h))

        # Corner order: (row 0, col 0), (row 1, col 0), (row 0, col 1), (row 1, col 1).
        corners = [grid[:, r::2, c::2, :] for c in (0, 1) for r in (0, 1)]  # each B H/2 W/2 C
        merged = torch.cat(corners, -1).view(B, -1, 4 * C)  # B H/2*W/2 4*C

        return self.reduction(self.norm(merged))
341
+
342
+
343
class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    The stage is a stack of ``depth`` ``SwinTransformerBlock`` modules whose
    shift alternates between 0 (even index) and ``window_size // 2`` (odd
    index), optionally followed by a downsample module.

    Args:
        dim (int): Number of feature channels.
        depth (int): Number of blocks in this stage.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate.
            A sequence is indexed per block; a scalar is shared by all blocks. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer class applied at the end
            of the stage; it is instantiated with ``dim`` and ``norm_layer``. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads,
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
    ):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    # even blocks use regular windows, odd blocks use shifted windows
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    # The docstring allows a tuple as well as a list of per-block rates;
                    # index either kind instead of handing the whole sequence to a block.
                    drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """Run all blocks of the stage, then the optional downsample.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            tuple: ``(x, H, W, x_down, Wh, Ww)`` where ``x`` is the output of the
            last block at resolution ``(H, W)`` and ``x_down`` is the downsampled
            feature at resolution ``(Wh, Ww)``. Without a downsample module the
            second triple simply repeats the first.
        """

        # calculate attention mask for SW-MSA
        # Round the resolution up to a multiple of the window size.
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device, dtype=x.dtype)  # 1 Hp Wp 1
        h_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        w_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        # Label each of the 3x3 regions with a distinct id.
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(
            img_mask, self.window_size
        )  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        # Pairwise id difference: non-zero means the two positions belong to
        # different regions and get a -100 bias; same-region pairs get 0.
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
            attn_mask == 0, float(0.0)
        )

        for blk in self.blocks:
            # Blocks read the current resolution from these attributes.
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            # Halved resolution, rounded up for odd sizes.
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W
457
+
458
+
459
class PatchEmbed(nn.Module):
    """Split an image into patches and project each patch to an embedding.

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer class applied to the
            embeddings; no normalization when None. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # A strided convolution whose kernel equals its stride embeds each patch.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        """Embed ``x`` (B, C, H, W) and return a (B, embed_dim, Wh, Ww) feature map."""
        patch_h, patch_w = self.patch_size
        _, _, H, W = x.size()

        # Pad on the right / bottom so both sides become multiples of the patch size.
        extra_w = W % patch_w
        if extra_w != 0:
            x = F.pad(x, (0, patch_w - extra_w))
        extra_h = H % patch_h
        if extra_h != 0:
            x = F.pad(x, (0, 0, 0, patch_h - extra_h))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is None:
            return x

        # Normalize over channels in token layout, then restore the map layout.
        Wh, Ww = x.size(2), x.size(3)
        tokens = self.norm(x.flatten(2).transpose(1, 2))
        return tokens.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
499
+
500
+
501
class SwinTransformer(nn.Module):
    """Swin Transformer backbone.

    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/pdf/2103.14030

    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention head of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        dilation (bool): If True, the second-to-last stage is not downsampled either
            (16x total downsampling according to the original note); otherwise 32x.
    """

    def __init__(
        self,
        pretrain_img_size=224,
        patch_size=4,
        in_chans=3,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.2,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        dilation=False,
        use_checkpoint=False,
    ):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.dilation = dilation

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [
                pretrain_img_size[0] // patch_size[0],
                pretrain_img_size[1] // patch_size[1],
            ]

            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
            )
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        # prepare downsample list: every stage but the last one downsamples
        downsamplelist = [PatchMerging for i in range(self.num_layers)]
        downsamplelist[-1] = None
        num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
        if self.dilation:
            # The second-to-last stage keeps its resolution, so the last stage
            # is configured with half the usual channel count.
            downsamplelist[-2] = None
            num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=num_features[i_layer],
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                # slice of the global schedule that belongs to this stage
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=downsamplelist[i_layer],
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)

        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f"norm{i_layer}"
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        """Put the frozen parts in eval mode and disable their gradients."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    # def init_weights(self, pretrained=None):
    #     """Initialize the weights in backbone.
    #     Args:
    #         pretrained (str, optional): Path to pre-trained weights.
    #             Defaults to None.
    #     """

    #     def _init_weights(m):
    #         if isinstance(m, nn.Linear):
    #             trunc_normal_(m.weight, std=.02)
    #             if isinstance(m, nn.Linear) and m.bias is not None:
    #                 nn.init.constant_(m.bias, 0)
    #         elif isinstance(m, nn.LayerNorm):
    #             nn.init.constant_(m.bias, 0)
    #             nn.init.constant_(m.weight, 1.0)

    #     if isinstance(pretrained, str):
    #         self.apply(_init_weights)
    #         logger = get_root_logger()
    #         load_checkpoint(self, pretrained, strict=False, logger=logger)
    #     elif pretrained is None:
    #         self.apply(_init_weights)
    #     else:
    #         raise TypeError('pretrained must be a str or None')

    def forward_raw(self, x):
        """Compute the multi-scale feature maps for a plain image tensor.

        Args:
            x: Image batch of size (B, in_chans, H, W).

        Returns:
            tuple[Tensor]: One (B, C_i, H_i, W_i) map per stage listed in
            ``out_indices``, in stage order.
        """
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(
                self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
            )
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            # x_out/H/W: stage output before downsampling; x/Wh/Ww: input of the next stage
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f"norm{i}")
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)
        # Example recorded by the original authors:
        # in:
        #   torch.Size([2, 3, 1024, 1024])
        # outs:
        #   [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
        #    torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
        return tuple(outs)

    def forward(self, tensor_list: NestedTensor):
        """Compute the feature maps for a NestedTensor and attach resized masks.

        Args:
            tensor_list: Object exposing ``tensors`` (the image batch) and
                ``mask``; the mask must not be None when any feature map is produced.

        Returns:
            dict[int, NestedTensor]: Keyed by the position of the feature map in
            the ``forward_raw`` output (0, 1, ...), each holding the map and the
            mask interpolated to the map's spatial size.
        """
        # Same computation as forward_raw; only the mask handling is added here.
        outs = self.forward_raw(tensor_list.tensors)

        # collect for nesttensors
        outs_dict = {}
        for idx, out_i in enumerate(outs):
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
            outs_dict[idx] = NestedTensor(out_i, mask)

        return outs_dict

    def train(self, mode=True):
        """Switch training mode while keeping the frozen stages frozen.

        Returns:
            SwinTransformer: ``self``, as ``nn.Module.train`` does, so that
            ``model.train()`` and ``model.eval()`` can be chained or assigned.
        """
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()
        return self
760
+
761
+
762
def build_swin_transformer(modelname, pretrain_img_size, **kw):
    """Build a SwinTransformer from a named preset.

    Args:
        modelname (str): One of the preset names listed below.
        pretrain_img_size (int): Forwarded to ``SwinTransformer``.
        **kw: Extra constructor arguments; they take precedence over the preset.

    Returns:
        SwinTransformer: The constructed backbone.
    """
    # preset name -> (embed_dim, depths, num_heads, window_size)
    presets = {
        "swin_T_224_1k": (96, [2, 2, 6, 2], [3, 6, 12, 24], 7),
        "swin_B_224_22k": (128, [2, 2, 18, 2], [4, 8, 16, 32], 7),
        "swin_B_384_22k": (128, [2, 2, 18, 2], [4, 8, 16, 32], 12),
        "swin_L_224_22k": (192, [2, 2, 18, 2], [6, 12, 24, 48], 7),
        "swin_L_384_22k": (192, [2, 2, 18, 2], [6, 12, 24, 48], 12),
    }
    assert modelname in list(presets)

    embed_dim, depths, num_heads, window_size = presets[modelname]
    config = dict(
        embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size
    )
    config.update(kw)
    return SwinTransformer(pretrain_img_size=pretrain_img_size, **config)
792
+
793
+
794
if __name__ == "__main__":
    # Smoke test: build the dilated Swin-L backbone and run it at two input
    # sizes. The shapes are printed instead of stopping in an interactive
    # debugger, so the script also runs unattended and without ipdb installed.
    model = build_swin_transformer("swin_L_384_22k", 384, dilation=True)
    for side in (1024, 384):
        x = torch.rand(2, 3, side, side)
        y = model.forward_raw(x)
        print(f"input {tuple(x.shape)} -> outputs {[tuple(out.shape) for out in y]}")
GroundingDINO/groundingdino/models/GroundingDINO/bertwarper.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint as checkpoint
11
+ from torch import Tensor, nn
12
+ from torchvision.ops.boxes import nms
13
+ from transformers import BertConfig, BertModel, BertPreTrainedModel
14
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
15
+
16
+
17
class BertModelWarper(nn.Module):
    """Re-expose the parts of a BERT model behind a custom ``forward``.

    The wrapper shares the wrapped model's ``config``, ``embeddings``,
    ``encoder`` and ``pooler`` and borrows its mask helper methods.
    """

    def __init__(self, bert_model):
        super().__init__()
        self.config = bert_model.config

        # Submodules of the wrapped model (shared, not copied).
        self.embeddings = bert_model.embeddings
        self.encoder = bert_model.encoder
        self.pooler = bert_model.pooler

        # Bound helper methods of the wrapped model.
        self.get_extended_attention_mask = bert_model.get_extended_attention_mask
        self.invert_attention_mask = bert_model.invert_attention_mask
        self.get_head_mask = bert_model.get_head_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""Embed the input, run the encoder and pool the result.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        ``attention_mask`` defaults to all ones and ``token_type_ids`` to all
        zeros. ``output_attentions``, ``output_hidden_states`` and
        ``return_dict`` fall back to the config when None.

        encoder_hidden_states (`optional`):
            Hidden states used for cross-attention; only consulted when the
            config marks the model as a decoder.
        encoder_attention_mask (`optional`):
            Mask for ``encoder_hidden_states`` (1 = not masked, 0 = masked);
            defaults to all ones when cross-attention is active.
        past_key_values (`optional`):
            Cached key/value states; the cached length is read from
            ``past_key_values[0][0].shape[2]``.
        use_cache (`optional`):
            Forced to False unless the model is a decoder; otherwise falls back
            to the config when None.

        Returns a tuple ``(sequence_output, pooled_output, *rest)`` when
        ``return_dict`` is false, else a
        ``BaseModelOutputWithPoolingAndCrossAttentions``.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if not cfg.is_decoder:
            use_cache = False
        elif use_cache is None:
            use_cache = cfg.use_cache

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            input_shape = input_ids.size()
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        batch_size, seq_length = input_shape

        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        if attention_mask is None:
            attention_mask = torch.ones(
                ((batch_size, seq_length + past_key_values_length)), device=device
            )
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # A self-attention mask of shape [batch_size, from_seq_length, to_seq_length]
        # may be supplied directly; the helper makes it broadcastable over heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape, device
        )

        # Cross-attention mask, only relevant for decoders with encoder states.
        encoder_extended_attention_mask = None
        if cfg.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(
                    (encoder_batch_size, encoder_sequence_length), device=device
                )
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        # 1.0 in head_mask means the head is kept.
        head_mask = self.get_head_mask(head_mask, cfg.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = None
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output)

        if return_dict:
            return BaseModelOutputWithPoolingAndCrossAttentions(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                past_key_values=encoder_outputs.past_key_values,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
                cross_attentions=encoder_outputs.cross_attentions,
            )
        return (sequence_output, pooled_output) + encoder_outputs[1:]
167
+
168
+
169
class TextEncoderShell(nn.Module):
    """Thin module wrapper that exposes a text encoder and its config."""

    def __init__(self, text_encoder):
        super().__init__()
        self.text_encoder = text_encoder
        self.config = text_encoder.config

    def forward(self, **kw):
        """Pass all keyword arguments straight to the wrapped text encoder."""
        encoded = self.text_encoder(**kw)
        return encoded
178
+
179
+
180
def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer):
    """Build a block-wise attention mask and position ids from special tokens.

    Special tokens act as segment delimiters: the tokens after the previous
    special token up to and including the current one attend to each other and
    are numbered from 0. A special token in the first or last column only
    attends to itself.

    Args:
        tokenized (dict): Must contain "input_ids", a tensor of shape [bs, num_token].
        special_tokens_list (list): Token ids treated as special tokens.
        tokenizer: Not used by this function.

    Returns:
        tuple: ``(attention_mask, position_ids)`` with a bool mask of shape
        [bs, num_token, num_token] and long position ids of shape [bs, num_token].
    """
    input_ids = tokenized["input_ids"]
    bs, num_token = input_ids.shape
    device = input_ids.device

    # is_special: bs, num_token. True for special tokens, False for normal tokens
    is_special = torch.zeros((bs, num_token), device=device).bool()
    for token_id in special_tokens_list:
        is_special |= input_ids == token_id

    # one (row, col) pair per special token, in row-major order
    special_positions = torch.nonzero(is_special)

    # start from "every token attends to itself"
    attention_mask = torch.eye(num_token, device=device).bool().unsqueeze(0).repeat(bs, 1, 1)
    position_ids = torch.zeros((bs, num_token), device=device)

    last_special = 0
    for row, col in special_positions:
        if (col == 0) or (col == num_token - 1):
            attention_mask[row, col, col] = True
            position_ids[row, col] = 0
        else:
            start, stop = last_special + 1, col + 1
            attention_mask[row, start:stop, start:stop] = True
            position_ids[row, start:stop] = torch.arange(0, col - last_special, device=device)
        last_special = col

    return attention_mask, position_ids.to(torch.long)
222
+
223
+
224
def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens_list, tokenizer):
    """Build the block-wise attention mask, position ids and category-to-token masks.

    Special tokens act as segment delimiters: the tokens after the previous
    special token up to and including the current one attend to each other and
    are numbered from 0. A special token in the first or last column only
    attends to itself. For every other special token, one extra mask marks the
    tokens strictly between the previous special token and this one.

    Args:
        tokenized (dict): Must contain "input_ids", a tensor of shape [bs, num_token].
        special_tokens_list (list): Token ids treated as special tokens.
        tokenizer: Not used by this function.

    Returns:
        tuple: ``(attention_mask, position_ids, cate_to_token_mask_list)`` where
        ``attention_mask`` is bool of shape [bs, num_token, num_token],
        ``position_ids`` is long of shape [bs, num_token], and
        ``cate_to_token_mask_list`` holds one bool tensor of shape
        [num_segments, num_token] per row (``num_segments`` is 0 for a row
        without any delimiting special token).
    """
    input_ids = tokenized["input_ids"]
    bs, num_token = input_ids.shape
    # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
    special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
    for special_token in special_tokens_list:
        special_tokens_mask |= input_ids == special_token

    # idxs: each row is the (row, col) index of one special token
    idxs = torch.nonzero(special_tokens_mask)

    # generate attention mask and positional ids
    attention_mask = (
        torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
    )
    position_ids = torch.zeros((bs, num_token), device=input_ids.device)
    cate_to_token_mask_list = [[] for _ in range(bs)]
    # NOTE(review): previous_col is carried across rows; this appears to assume every
    # row starts with a special token at column 0 (e.g. [CLS]) that resets it - confirm.
    previous_col = 0
    for i in range(idxs.shape[0]):
        row, col = idxs[i]
        if (col == 0) or (col == num_token - 1):
            attention_mask[row, col, col] = True
            position_ids[row, col] = 0
        else:
            attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
            position_ids[row, previous_col + 1 : col + 1] = torch.arange(
                0, col - previous_col, device=input_ids.device
            )
            # tokens of this segment, excluding the delimiting special token itself
            c2t_maski = torch.zeros((num_token), device=input_ids.device).bool()
            c2t_maski[previous_col + 1 : col] = True
            cate_to_token_mask_list[row].append(c2t_maski)
        previous_col = col

    # torch.stack rejects an empty list, so a row without any segment (for
    # example only leading/trailing special tokens) gets an empty
    # (0, num_token) mask instead of raising.
    cate_to_token_mask_list = [
        torch.stack(cate_to_token_mask_listi, dim=0)
        if cate_to_token_mask_listi
        else torch.zeros((0, num_token), device=input_ids.device).bool()
        for cate_to_token_mask_listi in cate_to_token_mask_list
    ]

    # # padding mask
    # padding_mask = tokenized['attention_mask']
    # attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()

    return attention_mask, position_ids.to(torch.long), cate_to_token_mask_list
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+
13
+ #include "ms_deform_attn_cpu.h"
14
+
15
+ #ifdef WITH_CUDA
16
+ #include "ms_deform_attn_cuda.h"
17
+ #endif
18
+
19
+ namespace groundingdino {
20
+
21
+ at::Tensor
22
+ ms_deform_attn_forward(
23
+ const at::Tensor &value,
24
+ const at::Tensor &spatial_shapes,
25
+ const at::Tensor &level_start_index,
26
+ const at::Tensor &sampling_loc,
27
+ const at::Tensor &attn_weight,
28
+ const int im2col_step)
29
+ {
30
+ if (value.type().is_cuda())
31
+ {
32
+ #ifdef WITH_CUDA
33
+ return ms_deform_attn_cuda_forward(
34
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
35
+ #else
36
+ AT_ERROR("Not compiled with GPU support");
37
+ #endif
38
+ }
39
+ AT_ERROR("Not implemented on the CPU");
40
+ }
41
+
42
+ std::vector<at::Tensor>
43
+ ms_deform_attn_backward(
44
+ const at::Tensor &value,
45
+ const at::Tensor &spatial_shapes,
46
+ const at::Tensor &level_start_index,
47
+ const at::Tensor &sampling_loc,
48
+ const at::Tensor &attn_weight,
49
+ const at::Tensor &grad_output,
50
+ const int im2col_step)
51
+ {
52
+ if (value.type().is_cuda())
53
+ {
54
+ #ifdef WITH_CUDA
55
+ return ms_deform_attn_cuda_backward(
56
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
57
+ #else
58
+ AT_ERROR("Not compiled with GPU support");
59
+ #endif
60
+ }
61
+ AT_ERROR("Not implemented on the CPU");
62
+ }
63
+
64
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+
13
+ #include <ATen/ATen.h>
14
+ #include <ATen/cuda/CUDAContext.h>
15
+
16
+ namespace groundingdino {
17
+
18
+ at::Tensor
19
+ ms_deform_attn_cpu_forward(
20
+ const at::Tensor &value,
21
+ const at::Tensor &spatial_shapes,
22
+ const at::Tensor &level_start_index,
23
+ const at::Tensor &sampling_loc,
24
+ const at::Tensor &attn_weight,
25
+ const int im2col_step)
26
+ {
27
+ AT_ERROR("Not implement on cpu");
28
+ }
29
+
30
+ std::vector<at::Tensor>
31
+ ms_deform_attn_cpu_backward(
32
+ const at::Tensor &value,
33
+ const at::Tensor &spatial_shapes,
34
+ const at::Tensor &level_start_index,
35
+ const at::Tensor &sampling_loc,
36
+ const at::Tensor &attn_weight,
37
+ const at::Tensor &grad_output,
38
+ const int im2col_step)
39
+ {
40
+ AT_ERROR("Not implement on cpu");
41
+ }
42
+
43
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ namespace groundingdino {
15
+
16
+ at::Tensor
17
+ ms_deform_attn_cpu_forward(
18
+ const at::Tensor &value,
19
+ const at::Tensor &spatial_shapes,
20
+ const at::Tensor &level_start_index,
21
+ const at::Tensor &sampling_loc,
22
+ const at::Tensor &attn_weight,
23
+ const int im2col_step);
24
+
25
+ std::vector<at::Tensor>
26
+ ms_deform_attn_cpu_backward(
27
+ const at::Tensor &value,
28
+ const at::Tensor &spatial_shapes,
29
+ const at::Tensor &level_start_index,
30
+ const at::Tensor &sampling_loc,
31
+ const at::Tensor &attn_weight,
32
+ const at::Tensor &grad_output,
33
+ const int im2col_step);
34
+
35
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+ #include "ms_deform_im2col_cuda.cuh"
13
+
14
+ #include <ATen/ATen.h>
15
+ #include <ATen/cuda/CUDAContext.h>
16
+ #include <cuda.h>
17
+ #include <cuda_runtime.h>
18
+
19
+ namespace groundingdino {
20
+
21
+ at::Tensor ms_deform_attn_cuda_forward(
22
+ const at::Tensor &value,
23
+ const at::Tensor &spatial_shapes,
24
+ const at::Tensor &level_start_index,
25
+ const at::Tensor &sampling_loc,
26
+ const at::Tensor &attn_weight,
27
+ const int im2col_step)
28
+ {
29
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
30
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
31
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
32
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
33
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
34
+
35
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
36
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
37
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
38
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
39
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
40
+
41
+ const int batch = value.size(0);
42
+ const int spatial_size = value.size(1);
43
+ const int num_heads = value.size(2);
44
+ const int channels = value.size(3);
45
+
46
+ const int num_levels = spatial_shapes.size(0);
47
+
48
+ const int num_query = sampling_loc.size(1);
49
+ const int num_point = sampling_loc.size(4);
50
+
51
+ const int im2col_step_ = std::min(batch, im2col_step);
52
+
53
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
54
+
55
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
56
+
57
+ const int batch_n = im2col_step_;
58
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
59
+ auto per_value_size = spatial_size * num_heads * channels;
60
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
61
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
62
+ for (int n = 0; n < batch/im2col_step_; ++n)
63
+ {
64
+ auto columns = output_n.select(0, n);
65
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
66
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
67
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
68
+ spatial_shapes.data<int64_t>(),
69
+ level_start_index.data<int64_t>(),
70
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
71
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
72
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
73
+ columns.data<scalar_t>());
74
+
75
+ }));
76
+ }
77
+
78
+ output = output.view({batch, num_query, num_heads*channels});
79
+
80
+ return output;
81
+ }
82
+
83
+
84
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
85
+ const at::Tensor &value,
86
+ const at::Tensor &spatial_shapes,
87
+ const at::Tensor &level_start_index,
88
+ const at::Tensor &sampling_loc,
89
+ const at::Tensor &attn_weight,
90
+ const at::Tensor &grad_output,
91
+ const int im2col_step)
92
+ {
93
+
94
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
95
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
96
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
97
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
98
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
99
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
100
+
101
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
102
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
103
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
104
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
105
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
106
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
107
+
108
+ const int batch = value.size(0);
109
+ const int spatial_size = value.size(1);
110
+ const int num_heads = value.size(2);
111
+ const int channels = value.size(3);
112
+
113
+ const int num_levels = spatial_shapes.size(0);
114
+
115
+ const int num_query = sampling_loc.size(1);
116
+ const int num_point = sampling_loc.size(4);
117
+
118
+ const int im2col_step_ = std::min(batch, im2col_step);
119
+
120
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
121
+
122
+ auto grad_value = at::zeros_like(value);
123
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
124
+ auto grad_attn_weight = at::zeros_like(attn_weight);
125
+
126
+ const int batch_n = im2col_step_;
127
+ auto per_value_size = spatial_size * num_heads * channels;
128
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
129
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
130
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
131
+
132
+ for (int n = 0; n < batch/im2col_step_; ++n)
133
+ {
134
+ auto grad_output_g = grad_output_n.select(0, n);
135
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
136
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
137
+ grad_output_g.data<scalar_t>(),
138
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
139
+ spatial_shapes.data<int64_t>(),
140
+ level_start_index.data<int64_t>(),
141
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
142
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
143
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
144
+ grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
145
+ grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
146
+ grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
147
+
148
+ }));
149
+ }
150
+
151
+ return {
152
+ grad_value, grad_sampling_loc, grad_attn_weight
153
+ };
154
+ }
155
+
156
+ } // namespace groundingdino