YouLiXiya commited on
Commit
27d04f5
1 Parent(s): 14f2216

Upload 80 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. GroundingDINO/.asset/COCO.png +0 -0
  3. GroundingDINO/.asset/GD_GLIGEN.png +3 -0
  4. GroundingDINO/.asset/GD_SD.png +3 -0
  5. GroundingDINO/.asset/ODinW.png +0 -0
  6. GroundingDINO/.asset/arch.png +0 -0
  7. GroundingDINO/.asset/cats.png +0 -0
  8. GroundingDINO/.asset/hero_figure.png +3 -0
  9. GroundingDINO/LICENSE +201 -0
  10. GroundingDINO/README.md +163 -0
  11. GroundingDINO/demo/gradio_app.py +125 -0
  12. GroundingDINO/demo/inference_on_a_image.py +172 -0
  13. GroundingDINO/groundingdino/_C.cp38-win_amd64.pyd +0 -0
  14. GroundingDINO/groundingdino/__init__.py +0 -0
  15. GroundingDINO/groundingdino/__pycache__/__init__.cpython-38.pyc +0 -0
  16. GroundingDINO/groundingdino/config/GroundingDINO_SwinB.py +43 -0
  17. GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py +43 -0
  18. GroundingDINO/groundingdino/datasets/__init__.py +0 -0
  19. GroundingDINO/groundingdino/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
  20. GroundingDINO/groundingdino/datasets/__pycache__/transforms.cpython-38.pyc +0 -0
  21. GroundingDINO/groundingdino/datasets/transforms.py +311 -0
  22. GroundingDINO/groundingdino/models/GroundingDINO/__init__.py +15 -0
  23. GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/__init__.cpython-38.pyc +0 -0
  24. GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/bertwarper.cpython-38.pyc +0 -0
  25. GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/fuse_modules.cpython-38.pyc +0 -0
  26. GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/groundingdino.cpython-38.pyc +0 -0
  27. GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/ms_deform_attn.cpython-38.pyc +0 -0
  28. GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/transformer.cpython-38.pyc +0 -0
  29. GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/transformer_vanilla.cpython-38.pyc +0 -0
  30. GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/utils.cpython-38.pyc +0 -0
  31. GroundingDINO/groundingdino/models/GroundingDINO/backbone/__init__.py +1 -0
  32. GroundingDINO/groundingdino/models/GroundingDINO/backbone/__pycache__/__init__.cpython-38.pyc +0 -0
  33. GroundingDINO/groundingdino/models/GroundingDINO/backbone/__pycache__/backbone.cpython-38.pyc +0 -0
  34. GroundingDINO/groundingdino/models/GroundingDINO/backbone/__pycache__/position_encoding.cpython-38.pyc +0 -0
  35. GroundingDINO/groundingdino/models/GroundingDINO/backbone/__pycache__/swin_transformer.cpython-38.pyc +0 -0
  36. GroundingDINO/groundingdino/models/GroundingDINO/backbone/backbone.py +221 -0
  37. GroundingDINO/groundingdino/models/GroundingDINO/backbone/position_encoding.py +186 -0
  38. GroundingDINO/groundingdino/models/GroundingDINO/backbone/swin_transformer.py +802 -0
  39. GroundingDINO/groundingdino/models/GroundingDINO/bertwarper.py +273 -0
  40. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h +64 -0
  41. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp +43 -0
  42. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h +35 -0
  43. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu +156 -0
  44. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h +33 -0
  45. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh +1327 -0
  46. GroundingDINO/groundingdino/models/GroundingDINO/csrc/cuda_version.cu +7 -0
  47. GroundingDINO/groundingdino/models/GroundingDINO/csrc/vision.cpp +58 -0
  48. GroundingDINO/groundingdino/models/GroundingDINO/fuse_modules.py +297 -0
  49. GroundingDINO/groundingdino/models/GroundingDINO/groundingdino.py +395 -0
  50. GroundingDINO/groundingdino/models/GroundingDINO/ms_deform_attn.py +413 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ GroundingDINO/.asset/GD_GLIGEN.png filter=lfs diff=lfs merge=lfs -text
37
+ GroundingDINO/.asset/GD_SD.png filter=lfs diff=lfs merge=lfs -text
38
+ GroundingDINO/.asset/hero_figure.png filter=lfs diff=lfs merge=lfs -text
GroundingDINO/.asset/COCO.png ADDED
GroundingDINO/.asset/GD_GLIGEN.png ADDED

Git LFS Details

  • SHA256: 6e36d497ace68412ecd6c064fff6d7481a685963ffc2ec047a8892411fb0ab8e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.23 MB
GroundingDINO/.asset/GD_SD.png ADDED

Git LFS Details

  • SHA256: 92c8a690a2de028d42c9b876c73dca53b7736134eb77cce5b3cbda9d1c4b62de
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
GroundingDINO/.asset/ODinW.png ADDED
GroundingDINO/.asset/arch.png ADDED
GroundingDINO/.asset/cats.png ADDED
GroundingDINO/.asset/hero_figure.png ADDED

Git LFS Details

  • SHA256: 24b18b31e9f150bae0ae01b09608d7bf7fc34f42c8e17d85eda55ea4a55b1e91
  • Pointer size: 132 Bytes
  • Size of remote file: 2.98 MB
GroundingDINO/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2020 - present, Facebook, Inc
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
GroundingDINO/README.md ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Grounding DINO
2
+
3
+ ---
4
+
5
+ [![arXiv](https://img.shields.io/badge/arXiv-2303.05499-b31b1b.svg)](https://arxiv.org/abs/2303.05499)
6
+ [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/wxWDt5UiwY8)
7
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb)
8
+ [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/cMa77r3YrDk)
9
+ [![HuggingFace space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo)
10
+
11
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-mscoco)](https://paperswithcode.com/sota/zero-shot-object-detection-on-mscoco?p=grounding-dino-marrying-dino-with-grounded) \
12
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-odinw)](https://paperswithcode.com/sota/zero-shot-object-detection-on-odinw?p=grounding-dino-marrying-dino-with-grounded) \
13
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=grounding-dino-marrying-dino-with-grounded) \
14
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=grounding-dino-marrying-dino-with-grounded)
15
+
16
+
17
+
18
+ Official PyTorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499), a stronger open-set object detector. Code is available now!
19
+
20
+
21
+ ## Highlight
22
+
23
+ - **Open-Set Detection.** Detect **everything** with language!
24
+ - **High Performancce.** COCO zero-shot **52.5 AP** (training without COCO data!). COCO fine-tune **63.0 AP**.
25
+ - **Flexible.** Collaboration with Stable Diffusion for Image Editting.
26
+
27
+ ## News
28
+ [2023/03/28] A YouTube [video](https://youtu.be/cMa77r3YrDk) about Grounding DINO and basic object detection prompt engineering. [[SkalskiP](https://github.com/SkalskiP)] \
29
+ [2023/03/28] Add a [demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo) on Hugging Face Space! \
30
+ [2023/03/27] Support CPU-only mode. Now the model can run on machines without GPUs.\
31
+ [2023/03/25] A [demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) for Grounding DINO is available at Colab. [[SkalskiP](https://github.com/SkalskiP)] \
32
+ [2023/03/22] Code is available Now!
33
+
34
+ <details open>
35
+ <summary><font size="4">
36
+ Description
37
+ </font></summary>
38
+ <img src=".asset/hero_figure.png" alt="ODinW" width="100%">
39
+ </details>
40
+
41
+
42
+
43
+ ## TODO
44
+
45
+ - [x] Release inference code and demo.
46
+ - [x] Release checkpoints.
47
+ - [ ] Grounding DINO with Stable Diffusion and GLIGEN demos.
48
+ - [ ] Release training codes.
49
+
50
+ ## Install
51
+
52
+ If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set. It will be compiled under CPU-only mode if no CUDA available.
53
+
54
+ ```bash
55
+ pip install -e .
56
+ ```
57
+
58
+ ## Demo
59
+
60
+ ```bash
61
+ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
62
+ -c /path/to/config \
63
+ -p /path/to/checkpoint \
64
+ -i .asset/cats.png \
65
+ -o "outputs/0" \
66
+ -t "cat ear." \
67
+ [--cpu-only] # open it for cpu mode
68
+ ```
69
+ See the `demo/inference_on_a_image.py` for more details.
70
+
71
+ **Web UI**
72
+
73
+ We also provide a demo code to integrate Grounding DINO with Gradio Web UI. See the file `demo/gradio_app.py` for more details.
74
+
75
+ ## Checkpoints
76
+
77
+ <!-- insert a table -->
78
+ <table>
79
+ <thead>
80
+ <tr style="text-align: right;">
81
+ <th></th>
82
+ <th>name</th>
83
+ <th>backbone</th>
84
+ <th>Data</th>
85
+ <th>box AP on COCO</th>
86
+ <th>Checkpoint</th>
87
+ <th>Config</th>
88
+ </tr>
89
+ </thead>
90
+ <tbody>
91
+ <tr>
92
+ <th>1</th>
93
+ <td>GroundingDINO-T</td>
94
+ <td>Swin-T</td>
95
+ <td>O365,GoldG,Cap4M</td>
96
+ <td>48.4 (zero-shot) / 57.2 (fine-tune)</td>
97
+ <td><a href="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth">Github link</a> | <a href="https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth">HF link</a></td>
98
+ <td><a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/groundingdino/config/GroundingDINO_SwinT_OGC.py">link</a></td>
99
+ </tr>
100
+ </tbody>
101
+ </table>
102
+
103
+ ## Results
104
+
105
+ <details open>
106
+ <summary><font size="4">
107
+ COCO Object Detection Results
108
+ </font></summary>
109
+ <img src=".asset/COCO.png" alt="COCO" width="100%">
110
+ </details>
111
+
112
+ <details open>
113
+ <summary><font size="4">
114
+ ODinW Object Detection Results
115
+ </font></summary>
116
+ <img src=".asset/ODinW.png" alt="ODinW" width="100%">
117
+ </details>
118
+
119
+ <details open>
120
+ <summary><font size="4">
121
+ Marrying Grounding DINO with <a href="https://github.com/Stability-AI/StableDiffusion">Stable Diffusion</a> for Image Editing
122
+ </font></summary>
123
+ <img src=".asset/GD_SD.png" alt="GD_SD" width="100%">
124
+ </details>
125
+
126
+ <details open>
127
+ <summary><font size="4">
128
+ Marrying Grounding DINO with <a href="https://github.com/gligen/GLIGEN">GLIGEN</a> for more Detailed Image Editing
129
+ </font></summary>
130
+ <img src=".asset/GD_GLIGEN.png" alt="GD_GLIGEN" width="100%">
131
+ </details>
132
+
133
+ ## Model
134
+
135
+ Includes: a text backbone, an image backbone, a feature enhancer, a language-guided query selection, and a cross-modality decoder.
136
+
137
+ ![arch](.asset/arch.png)
138
+
139
+
140
+ ## Acknowledgement
141
+
142
+ Our model is related to [DINO](https://github.com/IDEA-Research/DINO) and [GLIP](https://github.com/microsoft/GLIP). Thanks for their great work!
143
+
144
+ We also thank great previous work including DETR, Deformable DETR, SMCA, Conditional DETR, Anchor DETR, Dynamic DETR, DAB-DETR, DN-DETR, etc. More related work are available at [Awesome Detection Transformer](https://github.com/IDEACVR/awesome-detection-transformer). A new toolbox [detrex](https://github.com/IDEA-Research/detrex) is available as well.
145
+
146
+ Thanks [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) and [GLIGEN](https://github.com/gligen/GLIGEN) for their awesome models.
147
+
148
+
149
+ ## Citation
150
+
151
+ If you find our work helpful for your research, please consider citing the following BibTeX entry.
152
+
153
+ ```bibtex
154
+ @inproceedings{ShilongLiu2023GroundingDM,
155
+ title={Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection},
156
+ author={Shilong Liu and Zhaoyang Zeng and Tianhe Ren and Feng Li and Hao Zhang and Jie Yang and Chunyuan Li and Jianwei Yang and Hang Su and Jun Zhu and Lei Zhang},
157
+ year={2023}
158
+ }
159
+ ```
160
+
161
+
162
+
163
+
GroundingDINO/demo/gradio_app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from functools import partial
3
+ import cv2
4
+ import requests
5
+ import os
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ import numpy as np
9
+ from pathlib import Path
10
+
11
+
12
+ import warnings
13
+
14
+ import torch
15
+
16
+ # prepare the environment
17
+ os.system("python setup.py build develop --user")
18
+ os.system("pip install packaging==21.3")
19
+ os.system("pip install gradio")
20
+
21
+
22
+ warnings.filterwarnings("ignore")
23
+
24
+ import gradio as gr
25
+
26
+ from groundingdino.models import build_model
27
+ from groundingdino.util.slconfig import SLConfig
28
+ from groundingdino.util.utils import clean_state_dict
29
+ from groundingdino.util.inference import annotate, load_image, predict
30
+ import groundingdino.datasets.transforms as T
31
+
32
+ from huggingface_hub import hf_hub_download
33
+
34
+
35
+
36
+ # Use this command for evaluate the GLIP-T model
37
+ config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
38
+ ckpt_repo_id = "ShilongLiu/GroundingDINO"
39
+ ckpt_filenmae = "groundingdino_swint_ogc.pth"
40
+
41
+
42
+ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
43
+ args = SLConfig.fromfile(model_config_path)
44
+ model = build_model(args)
45
+ args.device = device
46
+
47
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
48
+ checkpoint = torch.load(cache_file, map_location='cpu')
49
+ log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
50
+ print("Model loaded from {} \n => {}".format(cache_file, log))
51
+ _ = model.eval()
52
+ return model
53
+
54
+ def image_transform_grounding(init_image):
55
+ transform = T.Compose([
56
+ T.RandomResize([800], max_size=1333),
57
+ T.ToTensor(),
58
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
59
+ ])
60
+ image, _ = transform(init_image, None) # 3, h, w
61
+ return init_image, image
62
+
63
+ def image_transform_grounding_for_vis(init_image):
64
+ transform = T.Compose([
65
+ T.RandomResize([800], max_size=1333),
66
+ ])
67
+ image, _ = transform(init_image, None) # 3, h, w
68
+ return image
69
+
70
+ model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
71
+
72
+ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
73
+ init_image = input_image.convert("RGB")
74
+ original_size = init_image.size
75
+
76
+ _, image_tensor = image_transform_grounding(init_image)
77
+ image_pil: Image = image_transform_grounding_for_vis(init_image)
78
+
79
+ # run grounidng
80
+ boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
81
+ annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
82
+ image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
83
+
84
+
85
+ return image_with_box
86
+
87
+ if __name__ == "__main__":
88
+
89
+ parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
90
+ parser.add_argument("--debug", action="store_true", help="using debug mode")
91
+ parser.add_argument("--share", action="store_true", help="share the app")
92
+ args = parser.parse_args()
93
+
94
+ block = gr.Blocks().queue()
95
+ with block:
96
+ gr.Markdown("# [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)")
97
+ gr.Markdown("### Open-World Detection with Grounding DINO")
98
+
99
+ with gr.Row():
100
+ with gr.Column():
101
+ input_image = gr.Image(source='upload', type="pil")
102
+ grounding_caption = gr.Textbox(label="Detection Prompt")
103
+ run_button = gr.Button(label="Run")
104
+ with gr.Accordion("Advanced options", open=False):
105
+ box_threshold = gr.Slider(
106
+ label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
107
+ )
108
+ text_threshold = gr.Slider(
109
+ label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
110
+ )
111
+
112
+ with gr.Column():
113
+ gallery = gr.outputs.Image(
114
+ type="pil",
115
+ # label="grounding results"
116
+ ).style(full_width=True, full_height=True)
117
+ # gallery = gr.Gallery(label="Generated images", show_label=False).style(
118
+ # grid=[1], height="auto", container=True, full_width=True, full_height=True)
119
+
120
+ run_button.click(fn=run_grounding, inputs=[
121
+ input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
122
+
123
+
124
+ block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)
125
+
GroundingDINO/demo/inference_on_a_image.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image, ImageDraw, ImageFont
8
+
9
+ import groundingdino.datasets.transforms as T
10
+ from groundingdino.models import build_model
11
+ from groundingdino.util import box_ops
12
+ from groundingdino.util.slconfig import SLConfig
13
+ from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
14
+
15
+
16
+ def plot_boxes_to_image(image_pil, tgt):
17
+ H, W = tgt["size"]
18
+ boxes = tgt["boxes"]
19
+ labels = tgt["labels"]
20
+ assert len(boxes) == len(labels), "boxes and labels must have same length"
21
+
22
+ draw = ImageDraw.Draw(image_pil)
23
+ mask = Image.new("L", image_pil.size, 0)
24
+ mask_draw = ImageDraw.Draw(mask)
25
+
26
+ # draw boxes and masks
27
+ for box, label in zip(boxes, labels):
28
+ # from 0..1 to 0..W, 0..H
29
+ box = box * torch.Tensor([W, H, W, H])
30
+ # from xywh to xyxy
31
+ box[:2] -= box[2:] / 2
32
+ box[2:] += box[:2]
33
+ # random color
34
+ color = tuple(np.random.randint(0, 255, size=3).tolist())
35
+ # draw
36
+ x0, y0, x1, y1 = box
37
+ x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
38
+
39
+ draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
40
+ # draw.text((x0, y0), str(label), fill=color)
41
+
42
+ font = ImageFont.load_default()
43
+ if hasattr(font, "getbbox"):
44
+ bbox = draw.textbbox((x0, y0), str(label), font)
45
+ else:
46
+ w, h = draw.textsize(str(label), font)
47
+ bbox = (x0, y0, w + x0, y0 + h)
48
+ # bbox = draw.textbbox((x0, y0), str(label))
49
+ draw.rectangle(bbox, fill=color)
50
+ draw.text((x0, y0), str(label), fill="white")
51
+
52
+ mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
53
+
54
+ return image_pil, mask
55
+
56
+
57
+ def load_image(image_path):
58
+ # load image
59
+ image_pil = Image.open(image_path).convert("RGB") # load image
60
+
61
+ transform = T.Compose(
62
+ [
63
+ T.RandomResize([800], max_size=1333),
64
+ T.ToTensor(),
65
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
66
+ ]
67
+ )
68
+ image, _ = transform(image_pil, None) # 3, h, w
69
+ return image_pil, image
70
+
71
+
72
+ def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
73
+ args = SLConfig.fromfile(model_config_path)
74
+ args.device = "cuda" if not cpu_only else "cpu"
75
+ model = build_model(args)
76
+ checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
77
+ load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
78
+ print(load_res)
79
+ _ = model.eval()
80
+ return model
81
+
82
+
83
+ def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, cpu_only=False):
84
+ caption = caption.lower()
85
+ caption = caption.strip()
86
+ if not caption.endswith("."):
87
+ caption = caption + "."
88
+ device = "cuda" if not cpu_only else "cpu"
89
+ model = model.to(device)
90
+ image = image.to(device)
91
+ with torch.no_grad():
92
+ outputs = model(image[None], captions=[caption])
93
+ logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
94
+ boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
95
+ logits.shape[0]
96
+
97
+ # filter output
98
+ logits_filt = logits.clone()
99
+ boxes_filt = boxes.clone()
100
+ filt_mask = logits_filt.max(dim=1)[0] > box_threshold
101
+ logits_filt = logits_filt[filt_mask] # num_filt, 256
102
+ boxes_filt = boxes_filt[filt_mask] # num_filt, 4
103
+ logits_filt.shape[0]
104
+
105
+ # get phrase
106
+ tokenlizer = model.tokenizer
107
+ tokenized = tokenlizer(caption)
108
+ # build pred
109
+ pred_phrases = []
110
+ for logit, box in zip(logits_filt, boxes_filt):
111
+ pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
112
+ if with_logits:
113
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
114
+ else:
115
+ pred_phrases.append(pred_phrase)
116
+
117
+ return boxes_filt, pred_phrases
118
+
119
+
120
+ if __name__ == "__main__":
121
+
122
+ parser = argparse.ArgumentParser("Grounding DINO example", add_help=True)
123
+ parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file")
124
+ parser.add_argument(
125
+ "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
126
+ )
127
+ parser.add_argument("--image_path", "-i", type=str, required=True, help="path to image file")
128
+ parser.add_argument("--text_prompt", "-t", type=str, required=True, help="text prompt")
129
+ parser.add_argument(
130
+ "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
131
+ )
132
+
133
+ parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
134
+ parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
135
+
136
+ parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False")
137
+ args = parser.parse_args()
138
+
139
+ # cfg
140
+ config_file = args.config_file # change the path of the model config file
141
+ checkpoint_path = args.checkpoint_path # change the path of the model
142
+ image_path = args.image_path
143
+ text_prompt = args.text_prompt
144
+ output_dir = args.output_dir
145
+ box_threshold = args.box_threshold
146
+ text_threshold = args.text_threshold
147
+
148
+ # make dir
149
+ os.makedirs(output_dir, exist_ok=True)
150
+ # load image
151
+ image_pil, image = load_image(image_path)
152
+ # load model
153
+ model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only)
154
+
155
+ # visualize raw image
156
+ image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
157
+
158
+ # run model
159
+ boxes_filt, pred_phrases = get_grounding_output(
160
+ model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only
161
+ )
162
+
163
+ # visualize pred
164
+ size = image_pil.size
165
+ pred_dict = {
166
+ "boxes": boxes_filt,
167
+ "size": [size[1], size[0]], # H,W
168
+ "labels": pred_phrases,
169
+ }
170
+ # import ipdb; ipdb.set_trace()
171
+ image_with_box = plot_boxes_to_image(image_pil, pred_dict)[0]
172
+ image_with_box.save(os.path.join(output_dir, "pred.jpg"))
GroundingDINO/groundingdino/_C.cp38-win_amd64.pyd ADDED
Binary file (666 kB). View file
 
GroundingDINO/groundingdino/__init__.py ADDED
File without changes
GroundingDINO/groundingdino/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (155 Bytes). View file
 
GroundingDINO/groundingdino/config/GroundingDINO_SwinB.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size = 1
2
+ modelname = "groundingdino"
3
+ backbone = "swin_B_384_22k"
4
+ position_embedding = "sine"
5
+ pe_temperatureH = 20
6
+ pe_temperatureW = 20
7
+ return_interm_indices = [1, 2, 3]
8
+ backbone_freeze_keywords = None
9
+ enc_layers = 6
10
+ dec_layers = 6
11
+ pre_norm = False
12
+ dim_feedforward = 2048
13
+ hidden_dim = 256
14
+ dropout = 0.0
15
+ nheads = 8
16
+ num_queries = 900
17
+ query_dim = 4
18
+ num_patterns = 0
19
+ num_feature_levels = 4
20
+ enc_n_points = 4
21
+ dec_n_points = 4
22
+ two_stage_type = "standard"
23
+ two_stage_bbox_embed_share = False
24
+ two_stage_class_embed_share = False
25
+ transformer_activation = "relu"
26
+ dec_pred_bbox_embed_share = True
27
+ dn_box_noise_scale = 1.0
28
+ dn_label_noise_ratio = 0.5
29
+ dn_label_coef = 1.0
30
+ dn_bbox_coef = 1.0
31
+ embed_init_tgt = True
32
+ dn_labelbook_size = 2000
33
+ max_text_len = 256
34
+ text_encoder_type = "bert-base-uncased"
35
+ use_text_enhancer = True
36
+ use_fusion_layer = True
37
+ use_checkpoint = True
38
+ use_transformer_ckpt = True
39
+ use_text_cross_attention = True
40
+ text_dropout = 0.0
41
+ fusion_dropout = 0.0
42
+ fusion_droppath = 0.1
43
+ sub_sentence_present = True
GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size = 1
2
+ modelname = "groundingdino"
3
+ backbone = "swin_T_224_1k"
4
+ position_embedding = "sine"
5
+ pe_temperatureH = 20
6
+ pe_temperatureW = 20
7
+ return_interm_indices = [1, 2, 3]
8
+ backbone_freeze_keywords = None
9
+ enc_layers = 6
10
+ dec_layers = 6
11
+ pre_norm = False
12
+ dim_feedforward = 2048
13
+ hidden_dim = 256
14
+ dropout = 0.0
15
+ nheads = 8
16
+ num_queries = 900
17
+ query_dim = 4
18
+ num_patterns = 0
19
+ num_feature_levels = 4
20
+ enc_n_points = 4
21
+ dec_n_points = 4
22
+ two_stage_type = "standard"
23
+ two_stage_bbox_embed_share = False
24
+ two_stage_class_embed_share = False
25
+ transformer_activation = "relu"
26
+ dec_pred_bbox_embed_share = True
27
+ dn_box_noise_scale = 1.0
28
+ dn_label_noise_ratio = 0.5
29
+ dn_label_coef = 1.0
30
+ dn_bbox_coef = 1.0
31
+ embed_init_tgt = True
32
+ dn_labelbook_size = 2000
33
+ max_text_len = 256
34
+ text_encoder_type = "bert-base-uncased"
35
+ use_text_enhancer = True
36
+ use_fusion_layer = True
37
+ use_checkpoint = True
38
+ use_transformer_ckpt = True
39
+ use_text_cross_attention = True
40
+ text_dropout = 0.0
41
+ fusion_dropout = 0.0
42
+ fusion_droppath = 0.1
43
+ sub_sentence_present = True
GroundingDINO/groundingdino/datasets/__init__.py ADDED
File without changes
GroundingDINO/groundingdino/datasets/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (164 Bytes). View file
 
GroundingDINO/groundingdino/datasets/__pycache__/transforms.cpython-38.pyc ADDED
Binary file (10.3 kB). View file
 
GroundingDINO/groundingdino/datasets/transforms.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Transforms and data augmentation for both image + bbox.
4
+ """
5
+ import os
6
+ import random
7
+
8
+ import PIL
9
+ import torch
10
+ import torchvision.transforms as T
11
+ import torchvision.transforms.functional as F
12
+
13
+ from groundingdino.util.box_ops import box_xyxy_to_cxcywh
14
+ from groundingdino.util.misc import interpolate
15
+
16
+
17
+ def crop(image, target, region):
18
+ cropped_image = F.crop(image, *region)
19
+
20
+ target = target.copy()
21
+ i, j, h, w = region
22
+
23
+ # should we do something wrt the original size?
24
+ target["size"] = torch.tensor([h, w])
25
+
26
+ fields = ["labels", "area", "iscrowd", "positive_map"]
27
+
28
+ if "boxes" in target:
29
+ boxes = target["boxes"]
30
+ max_size = torch.as_tensor([w, h], dtype=torch.float32)
31
+ cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
32
+ cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
33
+ cropped_boxes = cropped_boxes.clamp(min=0)
34
+ area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
35
+ target["boxes"] = cropped_boxes.reshape(-1, 4)
36
+ target["area"] = area
37
+ fields.append("boxes")
38
+
39
+ if "masks" in target:
40
+ # FIXME should we update the area here if there are no boxes?
41
+ target["masks"] = target["masks"][:, i : i + h, j : j + w]
42
+ fields.append("masks")
43
+
44
+ # remove elements for which the boxes or masks that have zero area
45
+ if "boxes" in target or "masks" in target:
46
+ # favor boxes selection when defining which elements to keep
47
+ # this is compatible with previous implementation
48
+ if "boxes" in target:
49
+ cropped_boxes = target["boxes"].reshape(-1, 2, 2)
50
+ keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
51
+ else:
52
+ keep = target["masks"].flatten(1).any(1)
53
+
54
+ for field in fields:
55
+ if field in target:
56
+ target[field] = target[field][keep]
57
+
58
+ if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
59
+ # for debug and visualization only.
60
+ if "strings_positive" in target:
61
+ target["strings_positive"] = [
62
+ _i for _i, _j in zip(target["strings_positive"], keep) if _j
63
+ ]
64
+
65
+ return cropped_image, target
66
+
67
+
68
+ def hflip(image, target):
69
+ flipped_image = F.hflip(image)
70
+
71
+ w, h = image.size
72
+
73
+ target = target.copy()
74
+ if "boxes" in target:
75
+ boxes = target["boxes"]
76
+ boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
77
+ [w, 0, w, 0]
78
+ )
79
+ target["boxes"] = boxes
80
+
81
+ if "masks" in target:
82
+ target["masks"] = target["masks"].flip(-1)
83
+
84
+ return flipped_image, target
85
+
86
+
87
+ def resize(image, target, size, max_size=None):
88
+ # size can be min_size (scalar) or (w, h) tuple
89
+
90
+ def get_size_with_aspect_ratio(image_size, size, max_size=None):
91
+ w, h = image_size
92
+ if max_size is not None:
93
+ min_original_size = float(min((w, h)))
94
+ max_original_size = float(max((w, h)))
95
+ if max_original_size / min_original_size * size > max_size:
96
+ size = int(round(max_size * min_original_size / max_original_size))
97
+
98
+ if (w <= h and w == size) or (h <= w and h == size):
99
+ return (h, w)
100
+
101
+ if w < h:
102
+ ow = size
103
+ oh = int(size * h / w)
104
+ else:
105
+ oh = size
106
+ ow = int(size * w / h)
107
+
108
+ return (oh, ow)
109
+
110
+ def get_size(image_size, size, max_size=None):
111
+ if isinstance(size, (list, tuple)):
112
+ return size[::-1]
113
+ else:
114
+ return get_size_with_aspect_ratio(image_size, size, max_size)
115
+
116
+ size = get_size(image.size, size, max_size)
117
+ rescaled_image = F.resize(image, size)
118
+
119
+ if target is None:
120
+ return rescaled_image, None
121
+
122
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
123
+ ratio_width, ratio_height = ratios
124
+
125
+ target = target.copy()
126
+ if "boxes" in target:
127
+ boxes = target["boxes"]
128
+ scaled_boxes = boxes * torch.as_tensor(
129
+ [ratio_width, ratio_height, ratio_width, ratio_height]
130
+ )
131
+ target["boxes"] = scaled_boxes
132
+
133
+ if "area" in target:
134
+ area = target["area"]
135
+ scaled_area = area * (ratio_width * ratio_height)
136
+ target["area"] = scaled_area
137
+
138
+ h, w = size
139
+ target["size"] = torch.tensor([h, w])
140
+
141
+ if "masks" in target:
142
+ target["masks"] = (
143
+ interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
144
+ )
145
+
146
+ return rescaled_image, target
147
+
148
+
149
+ def pad(image, target, padding):
150
+ # assumes that we only pad on the bottom right corners
151
+ padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
152
+ if target is None:
153
+ return padded_image, None
154
+ target = target.copy()
155
+ # should we do something wrt the original size?
156
+ target["size"] = torch.tensor(padded_image.size[::-1])
157
+ if "masks" in target:
158
+ target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1]))
159
+ return padded_image, target
160
+
161
+
162
+ class ResizeDebug(object):
163
+ def __init__(self, size):
164
+ self.size = size
165
+
166
+ def __call__(self, img, target):
167
+ return resize(img, target, self.size)
168
+
169
+
170
+ class RandomCrop(object):
171
+ def __init__(self, size):
172
+ self.size = size
173
+
174
+ def __call__(self, img, target):
175
+ region = T.RandomCrop.get_params(img, self.size)
176
+ return crop(img, target, region)
177
+
178
+
179
+ class RandomSizeCrop(object):
180
+ def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
181
+ # respect_boxes: True to keep all boxes
182
+ # False to tolerence box filter
183
+ self.min_size = min_size
184
+ self.max_size = max_size
185
+ self.respect_boxes = respect_boxes
186
+
187
+ def __call__(self, img: PIL.Image.Image, target: dict):
188
+ init_boxes = len(target["boxes"])
189
+ max_patience = 10
190
+ for i in range(max_patience):
191
+ w = random.randint(self.min_size, min(img.width, self.max_size))
192
+ h = random.randint(self.min_size, min(img.height, self.max_size))
193
+ region = T.RandomCrop.get_params(img, [h, w])
194
+ result_img, result_target = crop(img, target, region)
195
+ if (
196
+ not self.respect_boxes
197
+ or len(result_target["boxes"]) == init_boxes
198
+ or i == max_patience - 1
199
+ ):
200
+ return result_img, result_target
201
+ return result_img, result_target
202
+
203
+
204
+ class CenterCrop(object):
205
+ def __init__(self, size):
206
+ self.size = size
207
+
208
+ def __call__(self, img, target):
209
+ image_width, image_height = img.size
210
+ crop_height, crop_width = self.size
211
+ crop_top = int(round((image_height - crop_height) / 2.0))
212
+ crop_left = int(round((image_width - crop_width) / 2.0))
213
+ return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
214
+
215
+
216
+ class RandomHorizontalFlip(object):
217
+ def __init__(self, p=0.5):
218
+ self.p = p
219
+
220
+ def __call__(self, img, target):
221
+ if random.random() < self.p:
222
+ return hflip(img, target)
223
+ return img, target
224
+
225
+
226
+ class RandomResize(object):
227
+ def __init__(self, sizes, max_size=None):
228
+ assert isinstance(sizes, (list, tuple))
229
+ self.sizes = sizes
230
+ self.max_size = max_size
231
+
232
+ def __call__(self, img, target=None):
233
+ size = random.choice(self.sizes)
234
+ return resize(img, target, size, self.max_size)
235
+
236
+
237
+ class RandomPad(object):
238
+ def __init__(self, max_pad):
239
+ self.max_pad = max_pad
240
+
241
+ def __call__(self, img, target):
242
+ pad_x = random.randint(0, self.max_pad)
243
+ pad_y = random.randint(0, self.max_pad)
244
+ return pad(img, target, (pad_x, pad_y))
245
+
246
+
247
+ class RandomSelect(object):
248
+ """
249
+ Randomly selects between transforms1 and transforms2,
250
+ with probability p for transforms1 and (1 - p) for transforms2
251
+ """
252
+
253
+ def __init__(self, transforms1, transforms2, p=0.5):
254
+ self.transforms1 = transforms1
255
+ self.transforms2 = transforms2
256
+ self.p = p
257
+
258
+ def __call__(self, img, target):
259
+ if random.random() < self.p:
260
+ return self.transforms1(img, target)
261
+ return self.transforms2(img, target)
262
+
263
+
264
+ class ToTensor(object):
265
+ def __call__(self, img, target):
266
+ return F.to_tensor(img), target
267
+
268
+
269
+ class RandomErasing(object):
270
+ def __init__(self, *args, **kwargs):
271
+ self.eraser = T.RandomErasing(*args, **kwargs)
272
+
273
+ def __call__(self, img, target):
274
+ return self.eraser(img), target
275
+
276
+
277
+ class Normalize(object):
278
+ def __init__(self, mean, std):
279
+ self.mean = mean
280
+ self.std = std
281
+
282
+ def __call__(self, image, target=None):
283
+ image = F.normalize(image, mean=self.mean, std=self.std)
284
+ if target is None:
285
+ return image, None
286
+ target = target.copy()
287
+ h, w = image.shape[-2:]
288
+ if "boxes" in target:
289
+ boxes = target["boxes"]
290
+ boxes = box_xyxy_to_cxcywh(boxes)
291
+ boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
292
+ target["boxes"] = boxes
293
+ return image, target
294
+
295
+
296
+ class Compose(object):
297
+ def __init__(self, transforms):
298
+ self.transforms = transforms
299
+
300
+ def __call__(self, image, target):
301
+ for t in self.transforms:
302
+ image, target = t(image, target)
303
+ return image, target
304
+
305
+ def __repr__(self):
306
+ format_string = self.__class__.__name__ + "("
307
+ for t in self.transforms:
308
+ format_string += "\n"
309
+ format_string += " {0}".format(t)
310
+ format_string += "\n)"
311
+ return format_string
GroundingDINO/groundingdino/models/GroundingDINO/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Conditional DETR
8
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Copied from DETR (https://github.com/facebookresearch/detr)
12
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13
+ # ------------------------------------------------------------------------
14
+
15
+ from .groundingdino import build_groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (233 Bytes). View file
 
GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/bertwarper.cpython-38.pyc ADDED
Binary file (7.19 kB). View file
 
GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/fuse_modules.cpython-38.pyc ADDED
Binary file (7.8 kB). View file
 
GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/groundingdino.cpython-38.pyc ADDED
Binary file (10.5 kB). View file
 
GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/ms_deform_attn.cpython-38.pyc ADDED
Binary file (11.8 kB). View file
 
GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/transformer.cpython-38.pyc ADDED
Binary file (18.9 kB). View file
 
GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/transformer_vanilla.cpython-38.pyc ADDED
Binary file (3.42 kB). View file
 
GroundingDINO/groundingdino/models/GroundingDINO/__pycache__/utils.cpython-38.pyc ADDED
Binary file (9.56 kB). View file
 
GroundingDINO/groundingdino/models/GroundingDINO/backbone/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .backbone import build_backbone
GroundingDINO/groundingdino/models/GroundingDINO/backbone/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (232 Bytes). View file
 
GroundingDINO/groundingdino/models/GroundingDINO/backbone/__pycache__/backbone.cpython-38.pyc ADDED
Binary file (6.24 kB). View file
 
GroundingDINO/groundingdino/models/GroundingDINO/backbone/__pycache__/position_encoding.cpython-38.pyc ADDED
Binary file (5.18 kB). View file
 
GroundingDINO/groundingdino/models/GroundingDINO/backbone/__pycache__/swin_transformer.cpython-38.pyc ADDED
Binary file (20.7 kB). View file
 
GroundingDINO/groundingdino/models/GroundingDINO/backbone/backbone.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Conditional DETR
8
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Copied from DETR (https://github.com/facebookresearch/detr)
12
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13
+ # ------------------------------------------------------------------------
14
+
15
+ """
16
+ Backbone modules.
17
+ """
18
+
19
+ from typing import Dict, List
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torchvision
24
+ from torch import nn
25
+ from torchvision.models._utils import IntermediateLayerGetter
26
+
27
+ from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process
28
+
29
+ from .position_encoding import build_position_encoding
30
+ from .swin_transformer import build_swin_transformer
31
+
32
+
33
+ class FrozenBatchNorm2d(torch.nn.Module):
34
+ """
35
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
36
+
37
+ Copy-paste from torchvision.misc.ops with added eps before rqsrt,
38
+ without which any other models than torchvision.models.resnet[18,34,50,101]
39
+ produce nans.
40
+ """
41
+
42
+ def __init__(self, n):
43
+ super(FrozenBatchNorm2d, self).__init__()
44
+ self.register_buffer("weight", torch.ones(n))
45
+ self.register_buffer("bias", torch.zeros(n))
46
+ self.register_buffer("running_mean", torch.zeros(n))
47
+ self.register_buffer("running_var", torch.ones(n))
48
+
49
+ def _load_from_state_dict(
50
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
51
+ ):
52
+ num_batches_tracked_key = prefix + "num_batches_tracked"
53
+ if num_batches_tracked_key in state_dict:
54
+ del state_dict[num_batches_tracked_key]
55
+
56
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
57
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
58
+ )
59
+
60
+ def forward(self, x):
61
+ # move reshapes to the beginning
62
+ # to make it fuser-friendly
63
+ w = self.weight.reshape(1, -1, 1, 1)
64
+ b = self.bias.reshape(1, -1, 1, 1)
65
+ rv = self.running_var.reshape(1, -1, 1, 1)
66
+ rm = self.running_mean.reshape(1, -1, 1, 1)
67
+ eps = 1e-5
68
+ scale = w * (rv + eps).rsqrt()
69
+ bias = b - rm * scale
70
+ return x * scale + bias
71
+
72
+
73
+ class BackboneBase(nn.Module):
74
+ def __init__(
75
+ self,
76
+ backbone: nn.Module,
77
+ train_backbone: bool,
78
+ num_channels: int,
79
+ return_interm_indices: list,
80
+ ):
81
+ super().__init__()
82
+ for name, parameter in backbone.named_parameters():
83
+ if (
84
+ not train_backbone
85
+ or "layer2" not in name
86
+ and "layer3" not in name
87
+ and "layer4" not in name
88
+ ):
89
+ parameter.requires_grad_(False)
90
+
91
+ return_layers = {}
92
+ for idx, layer_index in enumerate(return_interm_indices):
93
+ return_layers.update(
94
+ {"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
95
+ )
96
+
97
+ # if len:
98
+ # if use_stage1_feature:
99
+ # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
100
+ # else:
101
+ # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
102
+ # else:
103
+ # return_layers = {'layer4': "0"}
104
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
105
+ self.num_channels = num_channels
106
+
107
+ def forward(self, tensor_list: NestedTensor):
108
+ xs = self.body(tensor_list.tensors)
109
+ out: Dict[str, NestedTensor] = {}
110
+ for name, x in xs.items():
111
+ m = tensor_list.mask
112
+ assert m is not None
113
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
114
+ out[name] = NestedTensor(x, mask)
115
+ # import ipdb; ipdb.set_trace()
116
+ return out
117
+
118
+
119
+ class Backbone(BackboneBase):
120
+ """ResNet backbone with frozen BatchNorm."""
121
+
122
+ def __init__(
123
+ self,
124
+ name: str,
125
+ train_backbone: bool,
126
+ dilation: bool,
127
+ return_interm_indices: list,
128
+ batch_norm=FrozenBatchNorm2d,
129
+ ):
130
+ if name in ["resnet18", "resnet34", "resnet50", "resnet101"]:
131
+ backbone = getattr(torchvision.models, name)(
132
+ replace_stride_with_dilation=[False, False, dilation],
133
+ pretrained=is_main_process(),
134
+ norm_layer=batch_norm,
135
+ )
136
+ else:
137
+ raise NotImplementedError("Why you can get here with name {}".format(name))
138
+ # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
139
+ assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
140
+ assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
141
+ num_channels_all = [256, 512, 1024, 2048]
142
+ num_channels = num_channels_all[4 - len(return_interm_indices) :]
143
+ super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
144
+
145
+
146
+ class Joiner(nn.Sequential):
147
+ def __init__(self, backbone, position_embedding):
148
+ super().__init__(backbone, position_embedding)
149
+
150
+ def forward(self, tensor_list: NestedTensor):
151
+ xs = self[0](tensor_list)
152
+ out: List[NestedTensor] = []
153
+ pos = []
154
+ for name, x in xs.items():
155
+ out.append(x)
156
+ # position encoding
157
+ pos.append(self[1](x).to(x.tensors.dtype))
158
+
159
+ return out, pos
160
+
161
+
162
+ def build_backbone(args):
163
+ """
164
+ Useful args:
165
+ - backbone: backbone name
166
+ - lr_backbone:
167
+ - dilation
168
+ - return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
169
+ - backbone_freeze_keywords:
170
+ - use_checkpoint: for swin only for now
171
+
172
+ """
173
+ position_embedding = build_position_encoding(args)
174
+ train_backbone = True
175
+ if not train_backbone:
176
+ raise ValueError("Please set lr_backbone > 0")
177
+ return_interm_indices = args.return_interm_indices
178
+ assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
179
+ args.backbone_freeze_keywords
180
+ use_checkpoint = getattr(args, "use_checkpoint", False)
181
+
182
+ if args.backbone in ["resnet50", "resnet101"]:
183
+ backbone = Backbone(
184
+ args.backbone,
185
+ train_backbone,
186
+ args.dilation,
187
+ return_interm_indices,
188
+ batch_norm=FrozenBatchNorm2d,
189
+ )
190
+ bb_num_channels = backbone.num_channels
191
+ elif args.backbone in [
192
+ "swin_T_224_1k",
193
+ "swin_B_224_22k",
194
+ "swin_B_384_22k",
195
+ "swin_L_224_22k",
196
+ "swin_L_384_22k",
197
+ ]:
198
+ pretrain_img_size = int(args.backbone.split("_")[-2])
199
+ backbone = build_swin_transformer(
200
+ args.backbone,
201
+ pretrain_img_size=pretrain_img_size,
202
+ out_indices=tuple(return_interm_indices),
203
+ dilation=False,
204
+ use_checkpoint=use_checkpoint,
205
+ )
206
+
207
+ bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
208
+ else:
209
+ raise NotImplementedError("Unknown backbone {}".format(args.backbone))
210
+
211
+ assert len(bb_num_channels) == len(
212
+ return_interm_indices
213
+ ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
214
+
215
+ model = Joiner(backbone, position_embedding)
216
+ model.num_channels = bb_num_channels
217
+ assert isinstance(
218
+ bb_num_channels, List
219
+ ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
220
+ # import ipdb; ipdb.set_trace()
221
+ return model
GroundingDINO/groundingdino/models/GroundingDINO/backbone/position_encoding.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # DINO
8
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Conditional DETR
12
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
13
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
14
+ # ------------------------------------------------------------------------
15
+ # Copied from DETR (https://github.com/facebookresearch/detr)
16
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
17
+ # ------------------------------------------------------------------------
18
+
19
+ """
20
+ Various positional encodings for the transformer.
21
+ """
22
+ import math
23
+
24
+ import torch
25
+ from torch import nn
26
+
27
+ from groundingdino.util.misc import NestedTensor
28
+
29
+
30
+ class PositionEmbeddingSine(nn.Module):
31
+ """
32
+ This is a more standard version of the position embedding, very similar to the one
33
+ used by the Attention is all you need paper, generalized to work on images.
34
+ """
35
+
36
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
37
+ super().__init__()
38
+ self.num_pos_feats = num_pos_feats
39
+ self.temperature = temperature
40
+ self.normalize = normalize
41
+ if scale is not None and normalize is False:
42
+ raise ValueError("normalize should be True if scale is passed")
43
+ if scale is None:
44
+ scale = 2 * math.pi
45
+ self.scale = scale
46
+
47
+ def forward(self, tensor_list: NestedTensor):
48
+ x = tensor_list.tensors
49
+ mask = tensor_list.mask
50
+ assert mask is not None
51
+ not_mask = ~mask
52
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
53
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
54
+ if self.normalize:
55
+ eps = 1e-6
56
+ # if os.environ.get("SHILONG_AMP", None) == '1':
57
+ # eps = 1e-4
58
+ # else:
59
+ # eps = 1e-6
60
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
61
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
62
+
63
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
64
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
65
+
66
+ pos_x = x_embed[:, :, :, None] / dim_t
67
+ pos_y = y_embed[:, :, :, None] / dim_t
68
+ pos_x = torch.stack(
69
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
70
+ ).flatten(3)
71
+ pos_y = torch.stack(
72
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
73
+ ).flatten(3)
74
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
75
+ return pos
76
+
77
+
78
+ class PositionEmbeddingSineHW(nn.Module):
79
+ """
80
+ This is a more standard version of the position embedding, very similar to the one
81
+ used by the Attention is all you need paper, generalized to work on images.
82
+ """
83
+
84
+ def __init__(
85
+ self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None
86
+ ):
87
+ super().__init__()
88
+ self.num_pos_feats = num_pos_feats
89
+ self.temperatureH = temperatureH
90
+ self.temperatureW = temperatureW
91
+ self.normalize = normalize
92
+ if scale is not None and normalize is False:
93
+ raise ValueError("normalize should be True if scale is passed")
94
+ if scale is None:
95
+ scale = 2 * math.pi
96
+ self.scale = scale
97
+
98
+ def forward(self, tensor_list: NestedTensor):
99
+ x = tensor_list.tensors
100
+ mask = tensor_list.mask
101
+ assert mask is not None
102
+ not_mask = ~mask
103
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
104
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
105
+
106
+ # import ipdb; ipdb.set_trace()
107
+
108
+ if self.normalize:
109
+ eps = 1e-6
110
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
111
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
112
+
113
+ dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
114
+ dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
115
+ pos_x = x_embed[:, :, :, None] / dim_tx
116
+
117
+ dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
118
+ dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
119
+ pos_y = y_embed[:, :, :, None] / dim_ty
120
+
121
+ pos_x = torch.stack(
122
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
123
+ ).flatten(3)
124
+ pos_y = torch.stack(
125
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
126
+ ).flatten(3)
127
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
128
+
129
+ # import ipdb; ipdb.set_trace()
130
+
131
+ return pos
132
+
133
+
134
+ class PositionEmbeddingLearned(nn.Module):
135
+ """
136
+ Absolute pos embedding, learned.
137
+ """
138
+
139
+ def __init__(self, num_pos_feats=256):
140
+ super().__init__()
141
+ self.row_embed = nn.Embedding(50, num_pos_feats)
142
+ self.col_embed = nn.Embedding(50, num_pos_feats)
143
+ self.reset_parameters()
144
+
145
+ def reset_parameters(self):
146
+ nn.init.uniform_(self.row_embed.weight)
147
+ nn.init.uniform_(self.col_embed.weight)
148
+
149
+ def forward(self, tensor_list: NestedTensor):
150
+ x = tensor_list.tensors
151
+ h, w = x.shape[-2:]
152
+ i = torch.arange(w, device=x.device)
153
+ j = torch.arange(h, device=x.device)
154
+ x_emb = self.col_embed(i)
155
+ y_emb = self.row_embed(j)
156
+ pos = (
157
+ torch.cat(
158
+ [
159
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
160
+ y_emb.unsqueeze(1).repeat(1, w, 1),
161
+ ],
162
+ dim=-1,
163
+ )
164
+ .permute(2, 0, 1)
165
+ .unsqueeze(0)
166
+ .repeat(x.shape[0], 1, 1, 1)
167
+ )
168
+ return pos
169
+
170
+
171
+ def build_position_encoding(args):
172
+ N_steps = args.hidden_dim // 2
173
+ if args.position_embedding in ("v2", "sine"):
174
+ # TODO find a better way of exposing other arguments
175
+ position_embedding = PositionEmbeddingSineHW(
176
+ N_steps,
177
+ temperatureH=args.pe_temperatureH,
178
+ temperatureW=args.pe_temperatureW,
179
+ normalize=True,
180
+ )
181
+ elif args.position_embedding in ("v3", "learned"):
182
+ position_embedding = PositionEmbeddingLearned(N_steps)
183
+ else:
184
+ raise ValueError(f"not supported {args.position_embedding}")
185
+
186
+ return position_embedding
GroundingDINO/groundingdino/models/GroundingDINO/backbone/swin_transformer.py ADDED
@@ -0,0 +1,802 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # DINO
8
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # --------------------------------------------------------
11
+ # modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
12
+ # --------------------------------------------------------
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.utils.checkpoint as checkpoint
19
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
20
+
21
+ from groundingdino.util.misc import NestedTensor
22
+
23
+
24
+ class Mlp(nn.Module):
25
+ """Multilayer perceptron."""
26
+
27
+ def __init__(
28
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
29
+ ):
30
+ super().__init__()
31
+ out_features = out_features or in_features
32
+ hidden_features = hidden_features or in_features
33
+ self.fc1 = nn.Linear(in_features, hidden_features)
34
+ self.act = act_layer()
35
+ self.fc2 = nn.Linear(hidden_features, out_features)
36
+ self.drop = nn.Dropout(drop)
37
+
38
+ def forward(self, x):
39
+ x = self.fc1(x)
40
+ x = self.act(x)
41
+ x = self.drop(x)
42
+ x = self.fc2(x)
43
+ x = self.drop(x)
44
+ return x
45
+
46
+
47
+ def window_partition(x, window_size):
48
+ """
49
+ Args:
50
+ x: (B, H, W, C)
51
+ window_size (int): window size
52
+ Returns:
53
+ windows: (num_windows*B, window_size, window_size, C)
54
+ """
55
+ B, H, W, C = x.shape
56
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
57
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
58
+ return windows
59
+
60
+
61
+ def window_reverse(windows, window_size, H, W):
62
+ """
63
+ Args:
64
+ windows: (num_windows*B, window_size, window_size, C)
65
+ window_size (int): Window size
66
+ H (int): Height of image
67
+ W (int): Width of image
68
+ Returns:
69
+ x: (B, H, W, C)
70
+ """
71
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
72
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
73
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
74
+ return x
75
+
76
+
77
+ class WindowAttention(nn.Module):
78
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
79
+ It supports both of shifted and non-shifted window.
80
+ Args:
81
+ dim (int): Number of input channels.
82
+ window_size (tuple[int]): The height and width of the window.
83
+ num_heads (int): Number of attention heads.
84
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
85
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
86
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
87
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
88
+ """
89
+
90
+ def __init__(
91
+ self,
92
+ dim,
93
+ window_size,
94
+ num_heads,
95
+ qkv_bias=True,
96
+ qk_scale=None,
97
+ attn_drop=0.0,
98
+ proj_drop=0.0,
99
+ ):
100
+
101
+ super().__init__()
102
+ self.dim = dim
103
+ self.window_size = window_size # Wh, Ww
104
+ self.num_heads = num_heads
105
+ head_dim = dim // num_heads
106
+ self.scale = qk_scale or head_dim**-0.5
107
+
108
+ # define a parameter table of relative position bias
109
+ self.relative_position_bias_table = nn.Parameter(
110
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
111
+ ) # 2*Wh-1 * 2*Ww-1, nH
112
+
113
+ # get pair-wise relative position index for each token inside the window
114
+ coords_h = torch.arange(self.window_size[0])
115
+ coords_w = torch.arange(self.window_size[1])
116
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
117
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
118
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
119
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
120
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
121
+ relative_coords[:, :, 1] += self.window_size[1] - 1
122
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
123
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
124
+ self.register_buffer("relative_position_index", relative_position_index)
125
+
126
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
127
+ self.attn_drop = nn.Dropout(attn_drop)
128
+ self.proj = nn.Linear(dim, dim)
129
+ self.proj_drop = nn.Dropout(proj_drop)
130
+
131
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
132
+ self.softmax = nn.Softmax(dim=-1)
133
+
134
+ def forward(self, x, mask=None):
135
+ """Forward function.
136
+ Args:
137
+ x: input features with shape of (num_windows*B, N, C)
138
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
139
+ """
140
+ B_, N, C = x.shape
141
+ qkv = (
142
+ self.qkv(x)
143
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
144
+ .permute(2, 0, 3, 1, 4)
145
+ )
146
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
147
+
148
+ q = q * self.scale
149
+ attn = q @ k.transpose(-2, -1)
150
+
151
+ relative_position_bias = self.relative_position_bias_table[
152
+ self.relative_position_index.view(-1)
153
+ ].view(
154
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
155
+ ) # Wh*Ww,Wh*Ww,nH
156
+ relative_position_bias = relative_position_bias.permute(
157
+ 2, 0, 1
158
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
159
+ attn = attn + relative_position_bias.unsqueeze(0)
160
+
161
+ if mask is not None:
162
+ nW = mask.shape[0]
163
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
164
+ attn = attn.view(-1, self.num_heads, N, N)
165
+ attn = self.softmax(attn)
166
+ else:
167
+ attn = self.softmax(attn)
168
+
169
+ attn = self.attn_drop(attn)
170
+
171
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
172
+ x = self.proj(x)
173
+ x = self.proj_drop(x)
174
+ return x
175
+
176
+
177
+ class SwinTransformerBlock(nn.Module):
178
+ """Swin Transformer Block.
179
+ Args:
180
+ dim (int): Number of input channels.
181
+ num_heads (int): Number of attention heads.
182
+ window_size (int): Window size.
183
+ shift_size (int): Shift size for SW-MSA.
184
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
185
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
186
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
187
+ drop (float, optional): Dropout rate. Default: 0.0
188
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
189
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
190
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
191
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
192
+ """
193
+
194
+ def __init__(
195
+ self,
196
+ dim,
197
+ num_heads,
198
+ window_size=7,
199
+ shift_size=0,
200
+ mlp_ratio=4.0,
201
+ qkv_bias=True,
202
+ qk_scale=None,
203
+ drop=0.0,
204
+ attn_drop=0.0,
205
+ drop_path=0.0,
206
+ act_layer=nn.GELU,
207
+ norm_layer=nn.LayerNorm,
208
+ ):
209
+ super().__init__()
210
+ self.dim = dim
211
+ self.num_heads = num_heads
212
+ self.window_size = window_size
213
+ self.shift_size = shift_size
214
+ self.mlp_ratio = mlp_ratio
215
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
216
+
217
+ self.norm1 = norm_layer(dim)
218
+ self.attn = WindowAttention(
219
+ dim,
220
+ window_size=to_2tuple(self.window_size),
221
+ num_heads=num_heads,
222
+ qkv_bias=qkv_bias,
223
+ qk_scale=qk_scale,
224
+ attn_drop=attn_drop,
225
+ proj_drop=drop,
226
+ )
227
+
228
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
229
+ self.norm2 = norm_layer(dim)
230
+ mlp_hidden_dim = int(dim * mlp_ratio)
231
+ self.mlp = Mlp(
232
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
233
+ )
234
+
235
+ self.H = None
236
+ self.W = None
237
+
238
+ def forward(self, x, mask_matrix):
239
+ """Forward function.
240
+ Args:
241
+ x: Input feature, tensor size (B, H*W, C).
242
+ H, W: Spatial resolution of the input feature.
243
+ mask_matrix: Attention mask for cyclic shift.
244
+ """
245
+ B, L, C = x.shape
246
+ H, W = self.H, self.W
247
+ assert L == H * W, "input feature has wrong size"
248
+
249
+ shortcut = x
250
+ x = self.norm1(x)
251
+ x = x.view(B, H, W, C)
252
+
253
+ # pad feature maps to multiples of window size
254
+ pad_l = pad_t = 0
255
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
256
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
257
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
258
+ _, Hp, Wp, _ = x.shape
259
+
260
+ # cyclic shift
261
+ if self.shift_size > 0:
262
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
263
+ attn_mask = mask_matrix
264
+ else:
265
+ shifted_x = x
266
+ attn_mask = None
267
+
268
+ # partition windows
269
+ x_windows = window_partition(
270
+ shifted_x, self.window_size
271
+ ) # nW*B, window_size, window_size, C
272
+ x_windows = x_windows.view(
273
+ -1, self.window_size * self.window_size, C
274
+ ) # nW*B, window_size*window_size, C
275
+
276
+ # W-MSA/SW-MSA
277
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
278
+
279
+ # merge windows
280
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
281
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
282
+
283
+ # reverse cyclic shift
284
+ if self.shift_size > 0:
285
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
286
+ else:
287
+ x = shifted_x
288
+
289
+ if pad_r > 0 or pad_b > 0:
290
+ x = x[:, :H, :W, :].contiguous()
291
+
292
+ x = x.view(B, H * W, C)
293
+
294
+ # FFN
295
+ x = shortcut + self.drop_path(x)
296
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
297
+
298
+ return x
299
+
300
+
301
+ class PatchMerging(nn.Module):
302
+ """Patch Merging Layer
303
+ Args:
304
+ dim (int): Number of input channels.
305
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
306
+ """
307
+
308
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
309
+ super().__init__()
310
+ self.dim = dim
311
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
312
+ self.norm = norm_layer(4 * dim)
313
+
314
+ def forward(self, x, H, W):
315
+ """Forward function.
316
+ Args:
317
+ x: Input feature, tensor size (B, H*W, C).
318
+ H, W: Spatial resolution of the input feature.
319
+ """
320
+ B, L, C = x.shape
321
+ assert L == H * W, "input feature has wrong size"
322
+
323
+ x = x.view(B, H, W, C)
324
+
325
+ # padding
326
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
327
+ if pad_input:
328
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
329
+
330
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
331
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
332
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
333
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
334
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
335
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
336
+
337
+ x = self.norm(x)
338
+ x = self.reduction(x)
339
+
340
+ return x
341
+
342
+
343
+ class BasicLayer(nn.Module):
344
+ """A basic Swin Transformer layer for one stage.
345
+ Args:
346
+ dim (int): Number of feature channels
347
+ depth (int): Depths of this stage.
348
+ num_heads (int): Number of attention head.
349
+ window_size (int): Local window size. Default: 7.
350
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
351
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
352
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
353
+ drop (float, optional): Dropout rate. Default: 0.0
354
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
355
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
356
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
357
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
358
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
359
+ """
360
+
361
+ def __init__(
362
+ self,
363
+ dim,
364
+ depth,
365
+ num_heads,
366
+ window_size=7,
367
+ mlp_ratio=4.0,
368
+ qkv_bias=True,
369
+ qk_scale=None,
370
+ drop=0.0,
371
+ attn_drop=0.0,
372
+ drop_path=0.0,
373
+ norm_layer=nn.LayerNorm,
374
+ downsample=None,
375
+ use_checkpoint=False,
376
+ ):
377
+ super().__init__()
378
+ self.window_size = window_size
379
+ self.shift_size = window_size // 2
380
+ self.depth = depth
381
+ self.use_checkpoint = use_checkpoint
382
+
383
+ # build blocks
384
+ self.blocks = nn.ModuleList(
385
+ [
386
+ SwinTransformerBlock(
387
+ dim=dim,
388
+ num_heads=num_heads,
389
+ window_size=window_size,
390
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
391
+ mlp_ratio=mlp_ratio,
392
+ qkv_bias=qkv_bias,
393
+ qk_scale=qk_scale,
394
+ drop=drop,
395
+ attn_drop=attn_drop,
396
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
397
+ norm_layer=norm_layer,
398
+ )
399
+ for i in range(depth)
400
+ ]
401
+ )
402
+
403
+ # patch merging layer
404
+ if downsample is not None:
405
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
406
+ else:
407
+ self.downsample = None
408
+
409
+ def forward(self, x, H, W):
410
+ """Forward function.
411
+ Args:
412
+ x: Input feature, tensor size (B, H*W, C).
413
+ H, W: Spatial resolution of the input feature.
414
+ """
415
+
416
+ # calculate attention mask for SW-MSA
417
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
418
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
419
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
420
+ h_slices = (
421
+ slice(0, -self.window_size),
422
+ slice(-self.window_size, -self.shift_size),
423
+ slice(-self.shift_size, None),
424
+ )
425
+ w_slices = (
426
+ slice(0, -self.window_size),
427
+ slice(-self.window_size, -self.shift_size),
428
+ slice(-self.shift_size, None),
429
+ )
430
+ cnt = 0
431
+ for h in h_slices:
432
+ for w in w_slices:
433
+ img_mask[:, h, w, :] = cnt
434
+ cnt += 1
435
+
436
+ mask_windows = window_partition(
437
+ img_mask, self.window_size
438
+ ) # nW, window_size, window_size, 1
439
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
440
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
441
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
442
+ attn_mask == 0, float(0.0)
443
+ )
444
+
445
+ for blk in self.blocks:
446
+ blk.H, blk.W = H, W
447
+ if self.use_checkpoint:
448
+ x = checkpoint.checkpoint(blk, x, attn_mask)
449
+ else:
450
+ x = blk(x, attn_mask)
451
+ if self.downsample is not None:
452
+ x_down = self.downsample(x, H, W)
453
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
454
+ return x, H, W, x_down, Wh, Ww
455
+ else:
456
+ return x, H, W, x, H, W
457
+
458
+
459
+ class PatchEmbed(nn.Module):
460
+ """Image to Patch Embedding
461
+ Args:
462
+ patch_size (int): Patch token size. Default: 4.
463
+ in_chans (int): Number of input image channels. Default: 3.
464
+ embed_dim (int): Number of linear projection output channels. Default: 96.
465
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
466
+ """
467
+
468
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
469
+ super().__init__()
470
+ patch_size = to_2tuple(patch_size)
471
+ self.patch_size = patch_size
472
+
473
+ self.in_chans = in_chans
474
+ self.embed_dim = embed_dim
475
+
476
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
477
+ if norm_layer is not None:
478
+ self.norm = norm_layer(embed_dim)
479
+ else:
480
+ self.norm = None
481
+
482
+ def forward(self, x):
483
+ """Forward function."""
484
+ # padding
485
+ _, _, H, W = x.size()
486
+ if W % self.patch_size[1] != 0:
487
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
488
+ if H % self.patch_size[0] != 0:
489
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
490
+
491
+ x = self.proj(x) # B C Wh Ww
492
+ if self.norm is not None:
493
+ Wh, Ww = x.size(2), x.size(3)
494
+ x = x.flatten(2).transpose(1, 2)
495
+ x = self.norm(x)
496
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
497
+
498
+ return x
499
+
500
+
501
+ class SwinTransformer(nn.Module):
502
+ """Swin Transformer backbone.
503
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
504
+ https://arxiv.org/pdf/2103.14030
505
+ Args:
506
+ pretrain_img_size (int): Input image size for training the pretrained model,
507
+ used in absolute postion embedding. Default 224.
508
+ patch_size (int | tuple(int)): Patch size. Default: 4.
509
+ in_chans (int): Number of input image channels. Default: 3.
510
+ embed_dim (int): Number of linear projection output channels. Default: 96.
511
+ depths (tuple[int]): Depths of each Swin Transformer stage.
512
+ num_heads (tuple[int]): Number of attention head of each stage.
513
+ window_size (int): Window size. Default: 7.
514
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
515
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
516
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
517
+ drop_rate (float): Dropout rate.
518
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
519
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
520
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
521
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
522
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
523
+ out_indices (Sequence[int]): Output from which stages.
524
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
525
+ -1 means not freezing any parameters.
526
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
527
+ dilation (bool): if True, the output size if 16x downsample, ow 32x downsample.
528
+ """
529
+
530
+ def __init__(
531
+ self,
532
+ pretrain_img_size=224,
533
+ patch_size=4,
534
+ in_chans=3,
535
+ embed_dim=96,
536
+ depths=[2, 2, 6, 2],
537
+ num_heads=[3, 6, 12, 24],
538
+ window_size=7,
539
+ mlp_ratio=4.0,
540
+ qkv_bias=True,
541
+ qk_scale=None,
542
+ drop_rate=0.0,
543
+ attn_drop_rate=0.0,
544
+ drop_path_rate=0.2,
545
+ norm_layer=nn.LayerNorm,
546
+ ape=False,
547
+ patch_norm=True,
548
+ out_indices=(0, 1, 2, 3),
549
+ frozen_stages=-1,
550
+ dilation=False,
551
+ use_checkpoint=False,
552
+ ):
553
+ super().__init__()
554
+
555
+ self.pretrain_img_size = pretrain_img_size
556
+ self.num_layers = len(depths)
557
+ self.embed_dim = embed_dim
558
+ self.ape = ape
559
+ self.patch_norm = patch_norm
560
+ self.out_indices = out_indices
561
+ self.frozen_stages = frozen_stages
562
+ self.dilation = dilation
563
+
564
+ # if use_checkpoint:
565
+ # print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
566
+
567
+ # split image into non-overlapping patches
568
+ self.patch_embed = PatchEmbed(
569
+ patch_size=patch_size,
570
+ in_chans=in_chans,
571
+ embed_dim=embed_dim,
572
+ norm_layer=norm_layer if self.patch_norm else None,
573
+ )
574
+
575
+ # absolute position embedding
576
+ if self.ape:
577
+ pretrain_img_size = to_2tuple(pretrain_img_size)
578
+ patch_size = to_2tuple(patch_size)
579
+ patches_resolution = [
580
+ pretrain_img_size[0] // patch_size[0],
581
+ pretrain_img_size[1] // patch_size[1],
582
+ ]
583
+
584
+ self.absolute_pos_embed = nn.Parameter(
585
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
586
+ )
587
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
588
+
589
+ self.pos_drop = nn.Dropout(p=drop_rate)
590
+
591
+ # stochastic depth
592
+ dpr = [
593
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
594
+ ] # stochastic depth decay rule
595
+
596
+ # build layers
597
+ self.layers = nn.ModuleList()
598
+ # prepare downsample list
599
+ downsamplelist = [PatchMerging for i in range(self.num_layers)]
600
+ downsamplelist[-1] = None
601
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
602
+ if self.dilation:
603
+ downsamplelist[-2] = None
604
+ num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
605
+ for i_layer in range(self.num_layers):
606
+ layer = BasicLayer(
607
+ # dim=int(embed_dim * 2 ** i_layer),
608
+ dim=num_features[i_layer],
609
+ depth=depths[i_layer],
610
+ num_heads=num_heads[i_layer],
611
+ window_size=window_size,
612
+ mlp_ratio=mlp_ratio,
613
+ qkv_bias=qkv_bias,
614
+ qk_scale=qk_scale,
615
+ drop=drop_rate,
616
+ attn_drop=attn_drop_rate,
617
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
618
+ norm_layer=norm_layer,
619
+ # downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
620
+ downsample=downsamplelist[i_layer],
621
+ use_checkpoint=use_checkpoint,
622
+ )
623
+ self.layers.append(layer)
624
+
625
+ # num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
626
+ self.num_features = num_features
627
+
628
+ # add a norm layer for each output
629
+ for i_layer in out_indices:
630
+ layer = norm_layer(num_features[i_layer])
631
+ layer_name = f"norm{i_layer}"
632
+ self.add_module(layer_name, layer)
633
+
634
+ self._freeze_stages()
635
+
636
+ def _freeze_stages(self):
637
+ if self.frozen_stages >= 0:
638
+ self.patch_embed.eval()
639
+ for param in self.patch_embed.parameters():
640
+ param.requires_grad = False
641
+
642
+ if self.frozen_stages >= 1 and self.ape:
643
+ self.absolute_pos_embed.requires_grad = False
644
+
645
+ if self.frozen_stages >= 2:
646
+ self.pos_drop.eval()
647
+ for i in range(0, self.frozen_stages - 1):
648
+ m = self.layers[i]
649
+ m.eval()
650
+ for param in m.parameters():
651
+ param.requires_grad = False
652
+
653
+ # def init_weights(self, pretrained=None):
654
+ # """Initialize the weights in backbone.
655
+ # Args:
656
+ # pretrained (str, optional): Path to pre-trained weights.
657
+ # Defaults to None.
658
+ # """
659
+
660
+ # def _init_weights(m):
661
+ # if isinstance(m, nn.Linear):
662
+ # trunc_normal_(m.weight, std=.02)
663
+ # if isinstance(m, nn.Linear) and m.bias is not None:
664
+ # nn.init.constant_(m.bias, 0)
665
+ # elif isinstance(m, nn.LayerNorm):
666
+ # nn.init.constant_(m.bias, 0)
667
+ # nn.init.constant_(m.weight, 1.0)
668
+
669
+ # if isinstance(pretrained, str):
670
+ # self.apply(_init_weights)
671
+ # logger = get_root_logger()
672
+ # load_checkpoint(self, pretrained, strict=False, logger=logger)
673
+ # elif pretrained is None:
674
+ # self.apply(_init_weights)
675
+ # else:
676
+ # raise TypeError('pretrained must be a str or None')
677
+
678
+ def forward_raw(self, x):
679
+ """Forward function."""
680
+ x = self.patch_embed(x)
681
+
682
+ Wh, Ww = x.size(2), x.size(3)
683
+ if self.ape:
684
+ # interpolate the position embedding to the corresponding size
685
+ absolute_pos_embed = F.interpolate(
686
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
687
+ )
688
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
689
+ else:
690
+ x = x.flatten(2).transpose(1, 2)
691
+ x = self.pos_drop(x)
692
+
693
+ outs = []
694
+ for i in range(self.num_layers):
695
+ layer = self.layers[i]
696
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
697
+ # import ipdb; ipdb.set_trace()
698
+
699
+ if i in self.out_indices:
700
+ norm_layer = getattr(self, f"norm{i}")
701
+ x_out = norm_layer(x_out)
702
+
703
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
704
+ outs.append(out)
705
+ # in:
706
+ # torch.Size([2, 3, 1024, 1024])
707
+ # outs:
708
+ # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
709
+ # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
710
+ return tuple(outs)
711
+
712
+ def forward(self, tensor_list: NestedTensor):
713
+ x = tensor_list.tensors
714
+
715
+ """Forward function."""
716
+ x = self.patch_embed(x)
717
+
718
+ Wh, Ww = x.size(2), x.size(3)
719
+ if self.ape:
720
+ # interpolate the position embedding to the corresponding size
721
+ absolute_pos_embed = F.interpolate(
722
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
723
+ )
724
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
725
+ else:
726
+ x = x.flatten(2).transpose(1, 2)
727
+ x = self.pos_drop(x)
728
+
729
+ outs = []
730
+ for i in range(self.num_layers):
731
+ layer = self.layers[i]
732
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
733
+
734
+ if i in self.out_indices:
735
+ norm_layer = getattr(self, f"norm{i}")
736
+ x_out = norm_layer(x_out)
737
+
738
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
739
+ outs.append(out)
740
+ # in:
741
+ # torch.Size([2, 3, 1024, 1024])
742
+ # out:
743
+ # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
744
+ # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
745
+
746
+ # collect for nesttensors
747
+ outs_dict = {}
748
+ for idx, out_i in enumerate(outs):
749
+ m = tensor_list.mask
750
+ assert m is not None
751
+ mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
752
+ outs_dict[idx] = NestedTensor(out_i, mask)
753
+
754
+ return outs_dict
755
+
756
+ def train(self, mode=True):
757
+ """Convert the model into training mode while keep layers freezed."""
758
+ super(SwinTransformer, self).train(mode)
759
+ self._freeze_stages()
760
+
761
+
762
+ def build_swin_transformer(modelname, pretrain_img_size, **kw):
763
+ assert modelname in [
764
+ "swin_T_224_1k",
765
+ "swin_B_224_22k",
766
+ "swin_B_384_22k",
767
+ "swin_L_224_22k",
768
+ "swin_L_384_22k",
769
+ ]
770
+
771
+ model_para_dict = {
772
+ "swin_T_224_1k": dict(
773
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7
774
+ ),
775
+ "swin_B_224_22k": dict(
776
+ embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7
777
+ ),
778
+ "swin_B_384_22k": dict(
779
+ embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12
780
+ ),
781
+ "swin_L_224_22k": dict(
782
+ embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7
783
+ ),
784
+ "swin_L_384_22k": dict(
785
+ embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12
786
+ ),
787
+ }
788
+ kw_cgf = model_para_dict[modelname]
789
+ kw_cgf.update(kw)
790
+ model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cgf)
791
+ return model
792
+
793
+
794
+ if __name__ == "__main__":
795
+ model = build_swin_transformer("swin_L_384_22k", 384, dilation=True)
796
+ x = torch.rand(2, 3, 1024, 1024)
797
+ y = model.forward_raw(x)
798
+ import ipdb
799
+
800
+ ipdb.set_trace()
801
+ x = torch.rand(2, 3, 384, 384)
802
+ y = model.forward_raw(x)
GroundingDINO/groundingdino/models/GroundingDINO/bertwarper.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint as checkpoint
11
+ from torch import Tensor, nn
12
+ from torchvision.ops.boxes import nms
13
+ from transformers import BertConfig, BertModel, BertPreTrainedModel
14
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
15
+
16
+
17
+ class BertModelWarper(nn.Module):
18
+ def __init__(self, bert_model):
19
+ super().__init__()
20
+ # self.bert = bert_modelc
21
+
22
+ self.config = bert_model.config
23
+ self.embeddings = bert_model.embeddings
24
+ self.encoder = bert_model.encoder
25
+ self.pooler = bert_model.pooler
26
+
27
+ self.get_extended_attention_mask = bert_model.get_extended_attention_mask
28
+ self.invert_attention_mask = bert_model.invert_attention_mask
29
+ self.get_head_mask = bert_model.get_head_mask
30
+
31
+ def forward(
32
+ self,
33
+ input_ids=None,
34
+ attention_mask=None,
35
+ token_type_ids=None,
36
+ position_ids=None,
37
+ head_mask=None,
38
+ inputs_embeds=None,
39
+ encoder_hidden_states=None,
40
+ encoder_attention_mask=None,
41
+ past_key_values=None,
42
+ use_cache=None,
43
+ output_attentions=None,
44
+ output_hidden_states=None,
45
+ return_dict=None,
46
+ ):
47
+ r"""
48
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
49
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
50
+ the model is configured as a decoder.
51
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
52
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
53
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
54
+
55
+ - 1 for tokens that are **not masked**,
56
+ - 0 for tokens that are **masked**.
57
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
58
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
59
+
60
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
61
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
62
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
63
+ use_cache (:obj:`bool`, `optional`):
64
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
65
+ decoding (see :obj:`past_key_values`).
66
+ """
67
+ output_attentions = (
68
+ output_attentions if output_attentions is not None else self.config.output_attentions
69
+ )
70
+ output_hidden_states = (
71
+ output_hidden_states
72
+ if output_hidden_states is not None
73
+ else self.config.output_hidden_states
74
+ )
75
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
76
+
77
+ if self.config.is_decoder:
78
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
79
+ else:
80
+ use_cache = False
81
+
82
+ if input_ids is not None and inputs_embeds is not None:
83
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
84
+ elif input_ids is not None:
85
+ input_shape = input_ids.size()
86
+ batch_size, seq_length = input_shape
87
+ elif inputs_embeds is not None:
88
+ input_shape = inputs_embeds.size()[:-1]
89
+ batch_size, seq_length = input_shape
90
+ else:
91
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
92
+
93
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
94
+
95
+ # past_key_values_length
96
+ past_key_values_length = (
97
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
98
+ )
99
+
100
+ if attention_mask is None:
101
+ attention_mask = torch.ones(
102
+ ((batch_size, seq_length + past_key_values_length)), device=device
103
+ )
104
+ if token_type_ids is None:
105
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
106
+
107
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
108
+ # ourselves in which case we just need to make it broadcastable to all heads.
109
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
110
+ attention_mask, input_shape, device
111
+ )
112
+
113
+ # If a 2D or 3D attention mask is provided for the cross-attention
114
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
115
+ if self.config.is_decoder and encoder_hidden_states is not None:
116
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
117
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
118
+ if encoder_attention_mask is None:
119
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
120
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
121
+ else:
122
+ encoder_extended_attention_mask = None
123
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
124
+ # import ipdb; ipdb.set_trace()
125
+
126
+ # Prepare head mask if needed
127
+ # 1.0 in head_mask indicate we keep the head
128
+ # attention_probs has shape bsz x n_heads x N x N
129
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
130
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
131
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
132
+
133
+ embedding_output = self.embeddings(
134
+ input_ids=input_ids,
135
+ position_ids=position_ids,
136
+ token_type_ids=token_type_ids,
137
+ inputs_embeds=inputs_embeds,
138
+ past_key_values_length=past_key_values_length,
139
+ )
140
+
141
+ encoder_outputs = self.encoder(
142
+ embedding_output,
143
+ attention_mask=extended_attention_mask,
144
+ head_mask=head_mask,
145
+ encoder_hidden_states=encoder_hidden_states,
146
+ encoder_attention_mask=encoder_extended_attention_mask,
147
+ past_key_values=past_key_values,
148
+ use_cache=use_cache,
149
+ output_attentions=output_attentions,
150
+ output_hidden_states=output_hidden_states,
151
+ return_dict=return_dict,
152
+ )
153
+ sequence_output = encoder_outputs[0]
154
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
155
+
156
+ if not return_dict:
157
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
158
+
159
+ return BaseModelOutputWithPoolingAndCrossAttentions(
160
+ last_hidden_state=sequence_output,
161
+ pooler_output=pooled_output,
162
+ past_key_values=encoder_outputs.past_key_values,
163
+ hidden_states=encoder_outputs.hidden_states,
164
+ attentions=encoder_outputs.attentions,
165
+ cross_attentions=encoder_outputs.cross_attentions,
166
+ )
167
+
168
+
169
+ class TextEncoderShell(nn.Module):
170
+ def __init__(self, text_encoder):
171
+ super().__init__()
172
+ self.text_encoder = text_encoder
173
+ self.config = self.text_encoder.config
174
+
175
+ def forward(self, **kw):
176
+ # feed into text encoder
177
+ return self.text_encoder(**kw)
178
+
179
+
180
+ def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer):
181
+ """Generate attention mask between each pair of special tokens
182
+ Args:
183
+ input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
184
+ special_tokens_mask (list): special tokens mask.
185
+ Returns:
186
+ torch.Tensor: attention mask between each special tokens.
187
+ """
188
+ input_ids = tokenized["input_ids"]
189
+ bs, num_token = input_ids.shape
190
+ # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
191
+ special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
192
+ for special_token in special_tokens_list:
193
+ special_tokens_mask |= input_ids == special_token
194
+
195
+ # idxs: each row is a list of indices of special tokens
196
+ idxs = torch.nonzero(special_tokens_mask)
197
+
198
+ # generate attention mask and positional ids
199
+ attention_mask = (
200
+ torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
201
+ )
202
+ position_ids = torch.zeros((bs, num_token), device=input_ids.device)
203
+ previous_col = 0
204
+ for i in range(idxs.shape[0]):
205
+ row, col = idxs[i]
206
+ if (col == 0) or (col == num_token - 1):
207
+ attention_mask[row, col, col] = True
208
+ position_ids[row, col] = 0
209
+ else:
210
+ attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
211
+ position_ids[row, previous_col + 1 : col + 1] = torch.arange(
212
+ 0, col - previous_col, device=input_ids.device
213
+ )
214
+
215
+ previous_col = col
216
+
217
+ # # padding mask
218
+ # padding_mask = tokenized['attention_mask']
219
+ # attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
220
+
221
+ return attention_mask, position_ids.to(torch.long)
222
+
223
+
224
+ def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens_list, tokenizer):
225
+ """Generate attention mask between each pair of special tokens
226
+ Args:
227
+ input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
228
+ special_tokens_mask (list): special tokens mask.
229
+ Returns:
230
+ torch.Tensor: attention mask between each special tokens.
231
+ """
232
+ input_ids = tokenized["input_ids"]
233
+ bs, num_token = input_ids.shape
234
+ # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
235
+ special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
236
+ for special_token in special_tokens_list:
237
+ special_tokens_mask |= input_ids == special_token
238
+
239
+ # idxs: each row is a list of indices of special tokens
240
+ idxs = torch.nonzero(special_tokens_mask)
241
+
242
+ # generate attention mask and positional ids
243
+ attention_mask = (
244
+ torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
245
+ )
246
+ position_ids = torch.zeros((bs, num_token), device=input_ids.device)
247
+ cate_to_token_mask_list = [[] for _ in range(bs)]
248
+ previous_col = 0
249
+ for i in range(idxs.shape[0]):
250
+ row, col = idxs[i]
251
+ if (col == 0) or (col == num_token - 1):
252
+ attention_mask[row, col, col] = True
253
+ position_ids[row, col] = 0
254
+ else:
255
+ attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
256
+ position_ids[row, previous_col + 1 : col + 1] = torch.arange(
257
+ 0, col - previous_col, device=input_ids.device
258
+ )
259
+ c2t_maski = torch.zeros((num_token), device=input_ids.device).bool()
260
+ c2t_maski[previous_col + 1 : col] = True
261
+ cate_to_token_mask_list[row].append(c2t_maski)
262
+ previous_col = col
263
+
264
+ cate_to_token_mask_list = [
265
+ torch.stack(cate_to_token_mask_listi, dim=0)
266
+ for cate_to_token_mask_listi in cate_to_token_mask_list
267
+ ]
268
+
269
+ # # padding mask
270
+ # padding_mask = tokenized['attention_mask']
271
+ # attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
272
+
273
+ return attention_mask, position_ids.to(torch.long), cate_to_token_mask_list
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+
13
+ #include "ms_deform_attn_cpu.h"
14
+
15
+ #ifdef WITH_CUDA
16
+ #include "ms_deform_attn_cuda.h"
17
+ #endif
18
+
19
+ namespace groundingdino {
20
+
21
+ at::Tensor
22
+ ms_deform_attn_forward(
23
+ const at::Tensor &value,
24
+ const at::Tensor &spatial_shapes,
25
+ const at::Tensor &level_start_index,
26
+ const at::Tensor &sampling_loc,
27
+ const at::Tensor &attn_weight,
28
+ const int im2col_step)
29
+ {
30
+ if (value.type().is_cuda())
31
+ {
32
+ #ifdef WITH_CUDA
33
+ return ms_deform_attn_cuda_forward(
34
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
35
+ #else
36
+ AT_ERROR("Not compiled with GPU support");
37
+ #endif
38
+ }
39
+ AT_ERROR("Not implemented on the CPU");
40
+ }
41
+
42
+ std::vector<at::Tensor>
43
+ ms_deform_attn_backward(
44
+ const at::Tensor &value,
45
+ const at::Tensor &spatial_shapes,
46
+ const at::Tensor &level_start_index,
47
+ const at::Tensor &sampling_loc,
48
+ const at::Tensor &attn_weight,
49
+ const at::Tensor &grad_output,
50
+ const int im2col_step)
51
+ {
52
+ if (value.type().is_cuda())
53
+ {
54
+ #ifdef WITH_CUDA
55
+ return ms_deform_attn_cuda_backward(
56
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
57
+ #else
58
+ AT_ERROR("Not compiled with GPU support");
59
+ #endif
60
+ }
61
+ AT_ERROR("Not implemented on the CPU");
62
+ }
63
+
64
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+
13
+ #include <ATen/ATen.h>
14
+ #include <ATen/cuda/CUDAContext.h>
15
+
16
+ namespace groundingdino {
17
+
18
+ at::Tensor
19
+ ms_deform_attn_cpu_forward(
20
+ const at::Tensor &value,
21
+ const at::Tensor &spatial_shapes,
22
+ const at::Tensor &level_start_index,
23
+ const at::Tensor &sampling_loc,
24
+ const at::Tensor &attn_weight,
25
+ const int im2col_step)
26
+ {
27
+ AT_ERROR("Not implement on cpu");
28
+ }
29
+
30
+ std::vector<at::Tensor>
31
+ ms_deform_attn_cpu_backward(
32
+ const at::Tensor &value,
33
+ const at::Tensor &spatial_shapes,
34
+ const at::Tensor &level_start_index,
35
+ const at::Tensor &sampling_loc,
36
+ const at::Tensor &attn_weight,
37
+ const at::Tensor &grad_output,
38
+ const int im2col_step)
39
+ {
40
+ AT_ERROR("Not implement on cpu");
41
+ }
42
+
43
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ namespace groundingdino {
15
+
16
+ at::Tensor
17
+ ms_deform_attn_cpu_forward(
18
+ const at::Tensor &value,
19
+ const at::Tensor &spatial_shapes,
20
+ const at::Tensor &level_start_index,
21
+ const at::Tensor &sampling_loc,
22
+ const at::Tensor &attn_weight,
23
+ const int im2col_step);
24
+
25
+ std::vector<at::Tensor>
26
+ ms_deform_attn_cpu_backward(
27
+ const at::Tensor &value,
28
+ const at::Tensor &spatial_shapes,
29
+ const at::Tensor &level_start_index,
30
+ const at::Tensor &sampling_loc,
31
+ const at::Tensor &attn_weight,
32
+ const at::Tensor &grad_output,
33
+ const int im2col_step);
34
+
35
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+ #include "ms_deform_im2col_cuda.cuh"
13
+
14
+ #include <ATen/ATen.h>
15
+ #include <ATen/cuda/CUDAContext.h>
16
+ #include <cuda.h>
17
+ #include <cuda_runtime.h>
18
+
19
+ namespace groundingdino {
20
+
21
+ at::Tensor ms_deform_attn_cuda_forward(
22
+ const at::Tensor &value,
23
+ const at::Tensor &spatial_shapes,
24
+ const at::Tensor &level_start_index,
25
+ const at::Tensor &sampling_loc,
26
+ const at::Tensor &attn_weight,
27
+ const int im2col_step)
28
+ {
29
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
30
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
31
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
32
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
33
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
34
+
35
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
36
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
37
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
38
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
39
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
40
+
41
+ const int batch = value.size(0);
42
+ const int spatial_size = value.size(1);
43
+ const int num_heads = value.size(2);
44
+ const int channels = value.size(3);
45
+
46
+ const int num_levels = spatial_shapes.size(0);
47
+
48
+ const int num_query = sampling_loc.size(1);
49
+ const int num_point = sampling_loc.size(4);
50
+
51
+ const int im2col_step_ = std::min(batch, im2col_step);
52
+
53
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
54
+
55
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
56
+
57
+ const int batch_n = im2col_step_;
58
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
59
+ auto per_value_size = spatial_size * num_heads * channels;
60
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
61
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
62
+ for (int n = 0; n < batch/im2col_step_; ++n)
63
+ {
64
+ auto columns = output_n.select(0, n);
65
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
66
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
67
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
68
+ spatial_shapes.data<int64_t>(),
69
+ level_start_index.data<int64_t>(),
70
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
71
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
72
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
73
+ columns.data<scalar_t>());
74
+
75
+ }));
76
+ }
77
+
78
+ output = output.view({batch, num_query, num_heads*channels});
79
+
80
+ return output;
81
+ }
82
+
83
+
84
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
85
+ const at::Tensor &value,
86
+ const at::Tensor &spatial_shapes,
87
+ const at::Tensor &level_start_index,
88
+ const at::Tensor &sampling_loc,
89
+ const at::Tensor &attn_weight,
90
+ const at::Tensor &grad_output,
91
+ const int im2col_step)
92
+ {
93
+
94
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
95
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
96
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
97
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
98
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
99
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
100
+
101
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
102
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
103
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
104
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
105
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
106
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
107
+
108
+ const int batch = value.size(0);
109
+ const int spatial_size = value.size(1);
110
+ const int num_heads = value.size(2);
111
+ const int channels = value.size(3);
112
+
113
+ const int num_levels = spatial_shapes.size(0);
114
+
115
+ const int num_query = sampling_loc.size(1);
116
+ const int num_point = sampling_loc.size(4);
117
+
118
+ const int im2col_step_ = std::min(batch, im2col_step);
119
+
120
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
121
+
122
+ auto grad_value = at::zeros_like(value);
123
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
124
+ auto grad_attn_weight = at::zeros_like(attn_weight);
125
+
126
+ const int batch_n = im2col_step_;
127
+ auto per_value_size = spatial_size * num_heads * channels;
128
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
129
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
130
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
131
+
132
+ for (int n = 0; n < batch/im2col_step_; ++n)
133
+ {
134
+ auto grad_output_g = grad_output_n.select(0, n);
135
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
136
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
137
+ grad_output_g.data<scalar_t>(),
138
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
139
+ spatial_shapes.data<int64_t>(),
140
+ level_start_index.data<int64_t>(),
141
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
142
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
143
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
144
+ grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
145
+ grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
146
+ grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
147
+
148
+ }));
149
+ }
150
+
151
+ return {
152
+ grad_value, grad_sampling_loc, grad_attn_weight
153
+ };
154
+ }
155
+
156
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ namespace groundingdino {
15
+
16
+ at::Tensor ms_deform_attn_cuda_forward(
17
+ const at::Tensor &value,
18
+ const at::Tensor &spatial_shapes,
19
+ const at::Tensor &level_start_index,
20
+ const at::Tensor &sampling_loc,
21
+ const at::Tensor &attn_weight,
22
+ const int im2col_step);
23
+
24
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
25
+ const at::Tensor &value,
26
+ const at::Tensor &spatial_shapes,
27
+ const at::Tensor &level_start_index,
28
+ const at::Tensor &sampling_loc,
29
+ const at::Tensor &attn_weight,
30
+ const at::Tensor &grad_output,
31
+ const int im2col_step);
32
+
33
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh ADDED
@@ -0,0 +1,1327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************
7
+ * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
8
+ * Copyright (c) 2018 Microsoft
9
+ **************************************************************************
10
+ */
11
+
12
+ #include <cstdio>
13
+ #include <algorithm>
14
+ #include <cstring>
15
+
16
+ #include <ATen/ATen.h>
17
+ #include <ATen/cuda/CUDAContext.h>
18
+
19
+ #include <THC/THCAtomics.cuh>
20
+
21
+ #define CUDA_KERNEL_LOOP(i, n) \
22
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
23
+ i < (n); \
24
+ i += blockDim.x * gridDim.x)
25
+
26
+ const int CUDA_NUM_THREADS = 1024;
27
+ inline int GET_BLOCKS(const int N, const int num_threads)
28
+ {
29
+ return (N + num_threads - 1) / num_threads;
30
+ }
31
+
32
+
33
+ template <typename scalar_t>
34
+ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
35
+ const int &height, const int &width, const int &nheads, const int &channels,
36
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
37
+ {
38
+ const int h_low = floor(h);
39
+ const int w_low = floor(w);
40
+ const int h_high = h_low + 1;
41
+ const int w_high = w_low + 1;
42
+
43
+ const scalar_t lh = h - h_low;
44
+ const scalar_t lw = w - w_low;
45
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
46
+
47
+ const int w_stride = nheads * channels;
48
+ const int h_stride = width * w_stride;
49
+ const int h_low_ptr_offset = h_low * h_stride;
50
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
51
+ const int w_low_ptr_offset = w_low * w_stride;
52
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
53
+ const int base_ptr = m * channels + c;
54
+
55
+ scalar_t v1 = 0;
56
+ if (h_low >= 0 && w_low >= 0)
57
+ {
58
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
59
+ v1 = bottom_data[ptr1];
60
+ }
61
+ scalar_t v2 = 0;
62
+ if (h_low >= 0 && w_high <= width - 1)
63
+ {
64
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
65
+ v2 = bottom_data[ptr2];
66
+ }
67
+ scalar_t v3 = 0;
68
+ if (h_high <= height - 1 && w_low >= 0)
69
+ {
70
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
71
+ v3 = bottom_data[ptr3];
72
+ }
73
+ scalar_t v4 = 0;
74
+ if (h_high <= height - 1 && w_high <= width - 1)
75
+ {
76
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
77
+ v4 = bottom_data[ptr4];
78
+ }
79
+
80
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
81
+
82
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
83
+ return val;
84
+ }
85
+
86
+
87
+ template <typename scalar_t>
88
+ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
89
+ const int &height, const int &width, const int &nheads, const int &channels,
90
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
91
+ const scalar_t &top_grad,
92
+ const scalar_t &attn_weight,
93
+ scalar_t* &grad_value,
94
+ scalar_t* grad_sampling_loc,
95
+ scalar_t* grad_attn_weight)
96
+ {
97
+ const int h_low = floor(h);
98
+ const int w_low = floor(w);
99
+ const int h_high = h_low + 1;
100
+ const int w_high = w_low + 1;
101
+
102
+ const scalar_t lh = h - h_low;
103
+ const scalar_t lw = w - w_low;
104
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
105
+
106
+ const int w_stride = nheads * channels;
107
+ const int h_stride = width * w_stride;
108
+ const int h_low_ptr_offset = h_low * h_stride;
109
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
110
+ const int w_low_ptr_offset = w_low * w_stride;
111
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
112
+ const int base_ptr = m * channels + c;
113
+
114
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
115
+ const scalar_t top_grad_value = top_grad * attn_weight;
116
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
117
+
118
+ scalar_t v1 = 0;
119
+ if (h_low >= 0 && w_low >= 0)
120
+ {
121
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
122
+ v1 = bottom_data[ptr1];
123
+ grad_h_weight -= hw * v1;
124
+ grad_w_weight -= hh * v1;
125
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
126
+ }
127
+ scalar_t v2 = 0;
128
+ if (h_low >= 0 && w_high <= width - 1)
129
+ {
130
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
131
+ v2 = bottom_data[ptr2];
132
+ grad_h_weight -= lw * v2;
133
+ grad_w_weight += hh * v2;
134
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
135
+ }
136
+ scalar_t v3 = 0;
137
+ if (h_high <= height - 1 && w_low >= 0)
138
+ {
139
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
140
+ v3 = bottom_data[ptr3];
141
+ grad_h_weight += hw * v3;
142
+ grad_w_weight -= lh * v3;
143
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
144
+ }
145
+ scalar_t v4 = 0;
146
+ if (h_high <= height - 1 && w_high <= width - 1)
147
+ {
148
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
149
+ v4 = bottom_data[ptr4];
150
+ grad_h_weight += lw * v4;
151
+ grad_w_weight += lh * v4;
152
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
153
+ }
154
+
155
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
156
+ *grad_attn_weight = top_grad * val;
157
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
158
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
159
+ }
160
+
161
+
162
+ template <typename scalar_t>
163
+ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
164
+ const int &height, const int &width, const int &nheads, const int &channels,
165
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
166
+ const scalar_t &top_grad,
167
+ const scalar_t &attn_weight,
168
+ scalar_t* &grad_value,
169
+ scalar_t* grad_sampling_loc,
170
+ scalar_t* grad_attn_weight)
171
+ {
172
+ const int h_low = floor(h);
173
+ const int w_low = floor(w);
174
+ const int h_high = h_low + 1;
175
+ const int w_high = w_low + 1;
176
+
177
+ const scalar_t lh = h - h_low;
178
+ const scalar_t lw = w - w_low;
179
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
180
+
181
+ const int w_stride = nheads * channels;
182
+ const int h_stride = width * w_stride;
183
+ const int h_low_ptr_offset = h_low * h_stride;
184
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
185
+ const int w_low_ptr_offset = w_low * w_stride;
186
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
187
+ const int base_ptr = m * channels + c;
188
+
189
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
190
+ const scalar_t top_grad_value = top_grad * attn_weight;
191
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
192
+
193
+ scalar_t v1 = 0;
194
+ if (h_low >= 0 && w_low >= 0)
195
+ {
196
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
197
+ v1 = bottom_data[ptr1];
198
+ grad_h_weight -= hw * v1;
199
+ grad_w_weight -= hh * v1;
200
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
201
+ }
202
+ scalar_t v2 = 0;
203
+ if (h_low >= 0 && w_high <= width - 1)
204
+ {
205
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
206
+ v2 = bottom_data[ptr2];
207
+ grad_h_weight -= lw * v2;
208
+ grad_w_weight += hh * v2;
209
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
210
+ }
211
+ scalar_t v3 = 0;
212
+ if (h_high <= height - 1 && w_low >= 0)
213
+ {
214
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
215
+ v3 = bottom_data[ptr3];
216
+ grad_h_weight += hw * v3;
217
+ grad_w_weight -= lh * v3;
218
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
219
+ }
220
+ scalar_t v4 = 0;
221
+ if (h_high <= height - 1 && w_high <= width - 1)
222
+ {
223
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
224
+ v4 = bottom_data[ptr4];
225
+ grad_h_weight += lw * v4;
226
+ grad_w_weight += lh * v4;
227
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
228
+ }
229
+
230
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
231
+ atomicAdd(grad_attn_weight, top_grad * val);
232
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
233
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
234
+ }
235
+
236
+
237
+ template <typename scalar_t>
238
+ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
239
+ const scalar_t *data_value,
240
+ const int64_t *data_spatial_shapes,
241
+ const int64_t *data_level_start_index,
242
+ const scalar_t *data_sampling_loc,
243
+ const scalar_t *data_attn_weight,
244
+ const int batch_size,
245
+ const int spatial_size,
246
+ const int num_heads,
247
+ const int channels,
248
+ const int num_levels,
249
+ const int num_query,
250
+ const int num_point,
251
+ scalar_t *data_col)
252
+ {
253
+ CUDA_KERNEL_LOOP(index, n)
254
+ {
255
+ int _temp = index;
256
+ const int c_col = _temp % channels;
257
+ _temp /= channels;
258
+ const int sampling_index = _temp;
259
+ const int m_col = _temp % num_heads;
260
+ _temp /= num_heads;
261
+ const int q_col = _temp % num_query;
262
+ _temp /= num_query;
263
+ const int b_col = _temp;
264
+
265
+ scalar_t *data_col_ptr = data_col + index;
266
+ int data_weight_ptr = sampling_index * num_levels * num_point;
267
+ int data_loc_w_ptr = data_weight_ptr << 1;
268
+ const int qid_stride = num_heads * channels;
269
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
270
+ scalar_t col = 0;
271
+
272
+ for (int l_col=0; l_col < num_levels; ++l_col)
273
+ {
274
+ const int level_start_id = data_level_start_index[l_col];
275
+ const int spatial_h_ptr = l_col << 1;
276
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
277
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
278
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
279
+ for (int p_col=0; p_col < num_point; ++p_col)
280
+ {
281
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
282
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
283
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
284
+
285
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
286
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
287
+
288
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
289
+ {
290
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
291
+ }
292
+
293
+ data_weight_ptr += 1;
294
+ data_loc_w_ptr += 2;
295
+ }
296
+ }
297
+ *data_col_ptr = col;
298
+ }
299
+ }
300
+
301
+ template <typename scalar_t, unsigned int blockSize>
302
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
303
+ const scalar_t *grad_col,
304
+ const scalar_t *data_value,
305
+ const int64_t *data_spatial_shapes,
306
+ const int64_t *data_level_start_index,
307
+ const scalar_t *data_sampling_loc,
308
+ const scalar_t *data_attn_weight,
309
+ const int batch_size,
310
+ const int spatial_size,
311
+ const int num_heads,
312
+ const int channels,
313
+ const int num_levels,
314
+ const int num_query,
315
+ const int num_point,
316
+ scalar_t *grad_value,
317
+ scalar_t *grad_sampling_loc,
318
+ scalar_t *grad_attn_weight)
319
+ {
320
+ CUDA_KERNEL_LOOP(index, n)
321
+ {
322
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
323
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
324
+ unsigned int tid = threadIdx.x;
325
+ int _temp = index;
326
+ const int c_col = _temp % channels;
327
+ _temp /= channels;
328
+ const int sampling_index = _temp;
329
+ const int m_col = _temp % num_heads;
330
+ _temp /= num_heads;
331
+ const int q_col = _temp % num_query;
332
+ _temp /= num_query;
333
+ const int b_col = _temp;
334
+
335
+ const scalar_t top_grad = grad_col[index];
336
+
337
+ int data_weight_ptr = sampling_index * num_levels * num_point;
338
+ int data_loc_w_ptr = data_weight_ptr << 1;
339
+ const int grad_sampling_ptr = data_weight_ptr;
340
+ grad_sampling_loc += grad_sampling_ptr << 1;
341
+ grad_attn_weight += grad_sampling_ptr;
342
+ const int grad_weight_stride = 1;
343
+ const int grad_loc_stride = 2;
344
+ const int qid_stride = num_heads * channels;
345
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
346
+
347
+ for (int l_col=0; l_col < num_levels; ++l_col)
348
+ {
349
+ const int level_start_id = data_level_start_index[l_col];
350
+ const int spatial_h_ptr = l_col << 1;
351
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
352
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
353
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
354
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
355
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
356
+
357
+ for (int p_col=0; p_col < num_point; ++p_col)
358
+ {
359
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
360
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
361
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
362
+
363
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
364
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
365
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
366
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
367
+ *(cache_grad_attn_weight+threadIdx.x)=0;
368
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
369
+ {
370
+ ms_deform_attn_col2im_bilinear(
371
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
372
+ top_grad, weight, grad_value_ptr,
373
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
374
+ }
375
+
376
+ __syncthreads();
377
+ if (tid == 0)
378
+ {
379
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
380
+ int sid=2;
381
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
382
+ {
383
+ _grad_w += cache_grad_sampling_loc[sid];
384
+ _grad_h += cache_grad_sampling_loc[sid + 1];
385
+ _grad_a += cache_grad_attn_weight[tid];
386
+ sid += 2;
387
+ }
388
+
389
+
390
+ *grad_sampling_loc = _grad_w;
391
+ *(grad_sampling_loc + 1) = _grad_h;
392
+ *grad_attn_weight = _grad_a;
393
+ }
394
+ __syncthreads();
395
+
396
+ data_weight_ptr += 1;
397
+ data_loc_w_ptr += 2;
398
+ grad_attn_weight += grad_weight_stride;
399
+ grad_sampling_loc += grad_loc_stride;
400
+ }
401
+ }
402
+ }
403
+ }
404
+
405
+
406
+ template <typename scalar_t, unsigned int blockSize>
407
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
408
+ const scalar_t *grad_col,
409
+ const scalar_t *data_value,
410
+ const int64_t *data_spatial_shapes,
411
+ const int64_t *data_level_start_index,
412
+ const scalar_t *data_sampling_loc,
413
+ const scalar_t *data_attn_weight,
414
+ const int batch_size,
415
+ const int spatial_size,
416
+ const int num_heads,
417
+ const int channels,
418
+ const int num_levels,
419
+ const int num_query,
420
+ const int num_point,
421
+ scalar_t *grad_value,
422
+ scalar_t *grad_sampling_loc,
423
+ scalar_t *grad_attn_weight)
424
+ {
425
+ CUDA_KERNEL_LOOP(index, n)
426
+ {
427
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
428
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
429
+ unsigned int tid = threadIdx.x;
430
+ int _temp = index;
431
+ const int c_col = _temp % channels;
432
+ _temp /= channels;
433
+ const int sampling_index = _temp;
434
+ const int m_col = _temp % num_heads;
435
+ _temp /= num_heads;
436
+ const int q_col = _temp % num_query;
437
+ _temp /= num_query;
438
+ const int b_col = _temp;
439
+
440
+ const scalar_t top_grad = grad_col[index];
441
+
442
+ int data_weight_ptr = sampling_index * num_levels * num_point;
443
+ int data_loc_w_ptr = data_weight_ptr << 1;
444
+ const int grad_sampling_ptr = data_weight_ptr;
445
+ grad_sampling_loc += grad_sampling_ptr << 1;
446
+ grad_attn_weight += grad_sampling_ptr;
447
+ const int grad_weight_stride = 1;
448
+ const int grad_loc_stride = 2;
449
+ const int qid_stride = num_heads * channels;
450
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
451
+
452
+ for (int l_col=0; l_col < num_levels; ++l_col)
453
+ {
454
+ const int level_start_id = data_level_start_index[l_col];
455
+ const int spatial_h_ptr = l_col << 1;
456
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
457
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
458
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
459
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
460
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
461
+
462
+ for (int p_col=0; p_col < num_point; ++p_col)
463
+ {
464
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
465
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
466
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
467
+
468
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
469
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
470
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
471
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
472
+ *(cache_grad_attn_weight+threadIdx.x)=0;
473
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
474
+ {
475
+ ms_deform_attn_col2im_bilinear(
476
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
477
+ top_grad, weight, grad_value_ptr,
478
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
479
+ }
480
+
481
+ __syncthreads();
482
+
483
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
484
+ {
485
+ if (tid < s) {
486
+ const unsigned int xid1 = tid << 1;
487
+ const unsigned int xid2 = (tid + s) << 1;
488
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
489
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
490
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
491
+ }
492
+ __syncthreads();
493
+ }
494
+
495
+ if (tid == 0)
496
+ {
497
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
498
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
499
+ *grad_attn_weight = cache_grad_attn_weight[0];
500
+ }
501
+ __syncthreads();
502
+
503
+ data_weight_ptr += 1;
504
+ data_loc_w_ptr += 2;
505
+ grad_attn_weight += grad_weight_stride;
506
+ grad_sampling_loc += grad_loc_stride;
507
+ }
508
+ }
509
+ }
510
+ }
511
+
512
+
513
+ template <typename scalar_t>
514
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
515
+ const scalar_t *grad_col,
516
+ const scalar_t *data_value,
517
+ const int64_t *data_spatial_shapes,
518
+ const int64_t *data_level_start_index,
519
+ const scalar_t *data_sampling_loc,
520
+ const scalar_t *data_attn_weight,
521
+ const int batch_size,
522
+ const int spatial_size,
523
+ const int num_heads,
524
+ const int channels,
525
+ const int num_levels,
526
+ const int num_query,
527
+ const int num_point,
528
+ scalar_t *grad_value,
529
+ scalar_t *grad_sampling_loc,
530
+ scalar_t *grad_attn_weight)
531
+ {
532
+ CUDA_KERNEL_LOOP(index, n)
533
+ {
534
+ extern __shared__ int _s[];
535
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
536
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
537
+ unsigned int tid = threadIdx.x;
538
+ int _temp = index;
539
+ const int c_col = _temp % channels;
540
+ _temp /= channels;
541
+ const int sampling_index = _temp;
542
+ const int m_col = _temp % num_heads;
543
+ _temp /= num_heads;
544
+ const int q_col = _temp % num_query;
545
+ _temp /= num_query;
546
+ const int b_col = _temp;
547
+
548
+ const scalar_t top_grad = grad_col[index];
549
+
550
+ int data_weight_ptr = sampling_index * num_levels * num_point;
551
+ int data_loc_w_ptr = data_weight_ptr << 1;
552
+ const int grad_sampling_ptr = data_weight_ptr;
553
+ grad_sampling_loc += grad_sampling_ptr << 1;
554
+ grad_attn_weight += grad_sampling_ptr;
555
+ const int grad_weight_stride = 1;
556
+ const int grad_loc_stride = 2;
557
+ const int qid_stride = num_heads * channels;
558
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
559
+
560
+ for (int l_col=0; l_col < num_levels; ++l_col)
561
+ {
562
+ const int level_start_id = data_level_start_index[l_col];
563
+ const int spatial_h_ptr = l_col << 1;
564
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
565
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
566
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
567
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
568
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
569
+
570
+ for (int p_col=0; p_col < num_point; ++p_col)
571
+ {
572
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
573
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
574
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
575
+
576
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
577
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
578
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
579
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
580
+ *(cache_grad_attn_weight+threadIdx.x)=0;
581
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
582
+ {
583
+ ms_deform_attn_col2im_bilinear(
584
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
585
+ top_grad, weight, grad_value_ptr,
586
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
587
+ }
588
+
589
+ __syncthreads();
590
+ if (tid == 0)
591
+ {
592
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
593
+ int sid=2;
594
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
595
+ {
596
+ _grad_w += cache_grad_sampling_loc[sid];
597
+ _grad_h += cache_grad_sampling_loc[sid + 1];
598
+ _grad_a += cache_grad_attn_weight[tid];
599
+ sid += 2;
600
+ }
601
+
602
+
603
+ *grad_sampling_loc = _grad_w;
604
+ *(grad_sampling_loc + 1) = _grad_h;
605
+ *grad_attn_weight = _grad_a;
606
+ }
607
+ __syncthreads();
608
+
609
+ data_weight_ptr += 1;
610
+ data_loc_w_ptr += 2;
611
+ grad_attn_weight += grad_weight_stride;
612
+ grad_sampling_loc += grad_loc_stride;
613
+ }
614
+ }
615
+ }
616
+ }
617
+
618
+ template <typename scalar_t>
619
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
620
+ const scalar_t *grad_col,
621
+ const scalar_t *data_value,
622
+ const int64_t *data_spatial_shapes,
623
+ const int64_t *data_level_start_index,
624
+ const scalar_t *data_sampling_loc,
625
+ const scalar_t *data_attn_weight,
626
+ const int batch_size,
627
+ const int spatial_size,
628
+ const int num_heads,
629
+ const int channels,
630
+ const int num_levels,
631
+ const int num_query,
632
+ const int num_point,
633
+ scalar_t *grad_value,
634
+ scalar_t *grad_sampling_loc,
635
+ scalar_t *grad_attn_weight)
636
+ {
637
+ CUDA_KERNEL_LOOP(index, n)
638
+ {
639
+ extern __shared__ int _s[];
640
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
641
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
642
+ unsigned int tid = threadIdx.x;
643
+ int _temp = index;
644
+ const int c_col = _temp % channels;
645
+ _temp /= channels;
646
+ const int sampling_index = _temp;
647
+ const int m_col = _temp % num_heads;
648
+ _temp /= num_heads;
649
+ const int q_col = _temp % num_query;
650
+ _temp /= num_query;
651
+ const int b_col = _temp;
652
+
653
+ const scalar_t top_grad = grad_col[index];
654
+
655
+ int data_weight_ptr = sampling_index * num_levels * num_point;
656
+ int data_loc_w_ptr = data_weight_ptr << 1;
657
+ const int grad_sampling_ptr = data_weight_ptr;
658
+ grad_sampling_loc += grad_sampling_ptr << 1;
659
+ grad_attn_weight += grad_sampling_ptr;
660
+ const int grad_weight_stride = 1;
661
+ const int grad_loc_stride = 2;
662
+ const int qid_stride = num_heads * channels;
663
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
664
+
665
+ for (int l_col=0; l_col < num_levels; ++l_col)
666
+ {
667
+ const int level_start_id = data_level_start_index[l_col];
668
+ const int spatial_h_ptr = l_col << 1;
669
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
670
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
671
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
672
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
673
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
674
+
675
+ for (int p_col=0; p_col < num_point; ++p_col)
676
+ {
677
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
678
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
679
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
680
+
681
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
682
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
683
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
684
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
685
+ *(cache_grad_attn_weight+threadIdx.x)=0;
686
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
687
+ {
688
+ ms_deform_attn_col2im_bilinear(
689
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
690
+ top_grad, weight, grad_value_ptr,
691
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
692
+ }
693
+
694
+ __syncthreads();
695
+
696
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
697
+ {
698
+ if (tid < s) {
699
+ const unsigned int xid1 = tid << 1;
700
+ const unsigned int xid2 = (tid + s) << 1;
701
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
702
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
703
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
704
+ if (tid + (s << 1) < spre)
705
+ {
706
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
707
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
708
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
709
+ }
710
+ }
711
+ __syncthreads();
712
+ }
713
+
714
+ if (tid == 0)
715
+ {
716
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
717
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
718
+ *grad_attn_weight = cache_grad_attn_weight[0];
719
+ }
720
+ __syncthreads();
721
+
722
+ data_weight_ptr += 1;
723
+ data_loc_w_ptr += 2;
724
+ grad_attn_weight += grad_weight_stride;
725
+ grad_sampling_loc += grad_loc_stride;
726
+ }
727
+ }
728
+ }
729
+ }
730
+
731
+ template <typename scalar_t>
732
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
733
+ const scalar_t *grad_col,
734
+ const scalar_t *data_value,
735
+ const int64_t *data_spatial_shapes,
736
+ const int64_t *data_level_start_index,
737
+ const scalar_t *data_sampling_loc,
738
+ const scalar_t *data_attn_weight,
739
+ const int batch_size,
740
+ const int spatial_size,
741
+ const int num_heads,
742
+ const int channels,
743
+ const int num_levels,
744
+ const int num_query,
745
+ const int num_point,
746
+ scalar_t *grad_value,
747
+ scalar_t *grad_sampling_loc,
748
+ scalar_t *grad_attn_weight)
749
+ {
750
+ CUDA_KERNEL_LOOP(index, n)
751
+ {
752
+ extern __shared__ int _s[];
753
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
754
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
755
+ unsigned int tid = threadIdx.x;
756
+ int _temp = index;
757
+ const int c_col = _temp % channels;
758
+ _temp /= channels;
759
+ const int sampling_index = _temp;
760
+ const int m_col = _temp % num_heads;
761
+ _temp /= num_heads;
762
+ const int q_col = _temp % num_query;
763
+ _temp /= num_query;
764
+ const int b_col = _temp;
765
+
766
+ const scalar_t top_grad = grad_col[index];
767
+
768
+ int data_weight_ptr = sampling_index * num_levels * num_point;
769
+ int data_loc_w_ptr = data_weight_ptr << 1;
770
+ const int grad_sampling_ptr = data_weight_ptr;
771
+ grad_sampling_loc += grad_sampling_ptr << 1;
772
+ grad_attn_weight += grad_sampling_ptr;
773
+ const int grad_weight_stride = 1;
774
+ const int grad_loc_stride = 2;
775
+ const int qid_stride = num_heads * channels;
776
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
777
+
778
+ for (int l_col=0; l_col < num_levels; ++l_col)
779
+ {
780
+ const int level_start_id = data_level_start_index[l_col];
781
+ const int spatial_h_ptr = l_col << 1;
782
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
783
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
784
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
785
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
786
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
787
+
788
+ for (int p_col=0; p_col < num_point; ++p_col)
789
+ {
790
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
791
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
792
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
793
+
794
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
795
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
796
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
797
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
798
+ *(cache_grad_attn_weight+threadIdx.x)=0;
799
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
800
+ {
801
+ ms_deform_attn_col2im_bilinear(
802
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
803
+ top_grad, weight, grad_value_ptr,
804
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
805
+ }
806
+
807
+ __syncthreads();
808
+
809
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
810
+ {
811
+ if (tid < s) {
812
+ const unsigned int xid1 = tid << 1;
813
+ const unsigned int xid2 = (tid + s) << 1;
814
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
815
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
816
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
817
+ if (tid + (s << 1) < spre)
818
+ {
819
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
820
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
821
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
822
+ }
823
+ }
824
+ __syncthreads();
825
+ }
826
+
827
+ if (tid == 0)
828
+ {
829
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
830
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
831
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
832
+ }
833
+ __syncthreads();
834
+
835
+ data_weight_ptr += 1;
836
+ data_loc_w_ptr += 2;
837
+ grad_attn_weight += grad_weight_stride;
838
+ grad_sampling_loc += grad_loc_stride;
839
+ }
840
+ }
841
+ }
842
+ }
843
+
844
+
845
+ template <typename scalar_t>
846
+ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
847
+ const scalar_t *grad_col,
848
+ const scalar_t *data_value,
849
+ const int64_t *data_spatial_shapes,
850
+ const int64_t *data_level_start_index,
851
+ const scalar_t *data_sampling_loc,
852
+ const scalar_t *data_attn_weight,
853
+ const int batch_size,
854
+ const int spatial_size,
855
+ const int num_heads,
856
+ const int channels,
857
+ const int num_levels,
858
+ const int num_query,
859
+ const int num_point,
860
+ scalar_t *grad_value,
861
+ scalar_t *grad_sampling_loc,
862
+ scalar_t *grad_attn_weight)
863
+ {
864
+ CUDA_KERNEL_LOOP(index, n)
865
+ {
866
+ int _temp = index;
867
+ const int c_col = _temp % channels;
868
+ _temp /= channels;
869
+ const int sampling_index = _temp;
870
+ const int m_col = _temp % num_heads;
871
+ _temp /= num_heads;
872
+ const int q_col = _temp % num_query;
873
+ _temp /= num_query;
874
+ const int b_col = _temp;
875
+
876
+ const scalar_t top_grad = grad_col[index];
877
+
878
+ int data_weight_ptr = sampling_index * num_levels * num_point;
879
+ int data_loc_w_ptr = data_weight_ptr << 1;
880
+ const int grad_sampling_ptr = data_weight_ptr;
881
+ grad_sampling_loc += grad_sampling_ptr << 1;
882
+ grad_attn_weight += grad_sampling_ptr;
883
+ const int grad_weight_stride = 1;
884
+ const int grad_loc_stride = 2;
885
+ const int qid_stride = num_heads * channels;
886
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
887
+
888
+ for (int l_col=0; l_col < num_levels; ++l_col)
889
+ {
890
+ const int level_start_id = data_level_start_index[l_col];
891
+ const int spatial_h_ptr = l_col << 1;
892
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
893
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
894
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
895
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
896
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
897
+
898
+ for (int p_col=0; p_col < num_point; ++p_col)
899
+ {
900
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
901
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
902
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
903
+
904
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
905
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
906
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
907
+ {
908
+ ms_deform_attn_col2im_bilinear_gm(
909
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
910
+ top_grad, weight, grad_value_ptr,
911
+ grad_sampling_loc, grad_attn_weight);
912
+ }
913
+ data_weight_ptr += 1;
914
+ data_loc_w_ptr += 2;
915
+ grad_attn_weight += grad_weight_stride;
916
+ grad_sampling_loc += grad_loc_stride;
917
+ }
918
+ }
919
+ }
920
+ }
921
+
922
+
923
+ template <typename scalar_t>
924
+ void ms_deformable_im2col_cuda(cudaStream_t stream,
925
+ const scalar_t* data_value,
926
+ const int64_t* data_spatial_shapes,
927
+ const int64_t* data_level_start_index,
928
+ const scalar_t* data_sampling_loc,
929
+ const scalar_t* data_attn_weight,
930
+ const int batch_size,
931
+ const int spatial_size,
932
+ const int num_heads,
933
+ const int channels,
934
+ const int num_levels,
935
+ const int num_query,
936
+ const int num_point,
937
+ scalar_t* data_col)
938
+ {
939
+ const int num_kernels = batch_size * num_query * num_heads * channels;
940
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
941
+ const int num_threads = CUDA_NUM_THREADS;
942
+ ms_deformable_im2col_gpu_kernel<scalar_t>
943
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
944
+ 0, stream>>>(
945
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
946
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
947
+
948
+ cudaError_t err = cudaGetLastError();
949
+ if (err != cudaSuccess)
950
+ {
951
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
952
+ }
953
+
954
+ }
955
+
956
+ template <typename scalar_t>
957
+ void ms_deformable_col2im_cuda(cudaStream_t stream,
958
+ const scalar_t* grad_col,
959
+ const scalar_t* data_value,
960
+ const int64_t * data_spatial_shapes,
961
+ const int64_t * data_level_start_index,
962
+ const scalar_t * data_sampling_loc,
963
+ const scalar_t * data_attn_weight,
964
+ const int batch_size,
965
+ const int spatial_size,
966
+ const int num_heads,
967
+ const int channels,
968
+ const int num_levels,
969
+ const int num_query,
970
+ const int num_point,
971
+ scalar_t* grad_value,
972
+ scalar_t* grad_sampling_loc,
973
+ scalar_t* grad_attn_weight)
974
+ {
975
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
976
+ const int num_kernels = batch_size * num_query * num_heads * channels;
977
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
978
+ if (channels > 1024)
979
+ {
980
+ if ((channels & 1023) == 0)
981
+ {
982
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
983
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
984
+ num_threads*3*sizeof(scalar_t), stream>>>(
985
+ num_kernels,
986
+ grad_col,
987
+ data_value,
988
+ data_spatial_shapes,
989
+ data_level_start_index,
990
+ data_sampling_loc,
991
+ data_attn_weight,
992
+ batch_size,
993
+ spatial_size,
994
+ num_heads,
995
+ channels,
996
+ num_levels,
997
+ num_query,
998
+ num_point,
999
+ grad_value,
1000
+ grad_sampling_loc,
1001
+ grad_attn_weight);
1002
+ }
1003
+ else
1004
+ {
1005
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
1006
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1007
+ 0, stream>>>(
1008
+ num_kernels,
1009
+ grad_col,
1010
+ data_value,
1011
+ data_spatial_shapes,
1012
+ data_level_start_index,
1013
+ data_sampling_loc,
1014
+ data_attn_weight,
1015
+ batch_size,
1016
+ spatial_size,
1017
+ num_heads,
1018
+ channels,
1019
+ num_levels,
1020
+ num_query,
1021
+ num_point,
1022
+ grad_value,
1023
+ grad_sampling_loc,
1024
+ grad_attn_weight);
1025
+ }
1026
+ }
1027
+ else{
1028
+ switch(channels)
1029
+ {
1030
+ case 1:
1031
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
1032
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1033
+ 0, stream>>>(
1034
+ num_kernels,
1035
+ grad_col,
1036
+ data_value,
1037
+ data_spatial_shapes,
1038
+ data_level_start_index,
1039
+ data_sampling_loc,
1040
+ data_attn_weight,
1041
+ batch_size,
1042
+ spatial_size,
1043
+ num_heads,
1044
+ channels,
1045
+ num_levels,
1046
+ num_query,
1047
+ num_point,
1048
+ grad_value,
1049
+ grad_sampling_loc,
1050
+ grad_attn_weight);
1051
+ break;
1052
+ case 2:
1053
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
1054
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1055
+ 0, stream>>>(
1056
+ num_kernels,
1057
+ grad_col,
1058
+ data_value,
1059
+ data_spatial_shapes,
1060
+ data_level_start_index,
1061
+ data_sampling_loc,
1062
+ data_attn_weight,
1063
+ batch_size,
1064
+ spatial_size,
1065
+ num_heads,
1066
+ channels,
1067
+ num_levels,
1068
+ num_query,
1069
+ num_point,
1070
+ grad_value,
1071
+ grad_sampling_loc,
1072
+ grad_attn_weight);
1073
+ break;
1074
+ case 4:
1075
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
1076
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1077
+ 0, stream>>>(
1078
+ num_kernels,
1079
+ grad_col,
1080
+ data_value,
1081
+ data_spatial_shapes,
1082
+ data_level_start_index,
1083
+ data_sampling_loc,
1084
+ data_attn_weight,
1085
+ batch_size,
1086
+ spatial_size,
1087
+ num_heads,
1088
+ channels,
1089
+ num_levels,
1090
+ num_query,
1091
+ num_point,
1092
+ grad_value,
1093
+ grad_sampling_loc,
1094
+ grad_attn_weight);
1095
+ break;
1096
+ case 8:
1097
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
1098
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1099
+ 0, stream>>>(
1100
+ num_kernels,
1101
+ grad_col,
1102
+ data_value,
1103
+ data_spatial_shapes,
1104
+ data_level_start_index,
1105
+ data_sampling_loc,
1106
+ data_attn_weight,
1107
+ batch_size,
1108
+ spatial_size,
1109
+ num_heads,
1110
+ channels,
1111
+ num_levels,
1112
+ num_query,
1113
+ num_point,
1114
+ grad_value,
1115
+ grad_sampling_loc,
1116
+ grad_attn_weight);
1117
+ break;
1118
+ case 16:
1119
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
1120
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1121
+ 0, stream>>>(
1122
+ num_kernels,
1123
+ grad_col,
1124
+ data_value,
1125
+ data_spatial_shapes,
1126
+ data_level_start_index,
1127
+ data_sampling_loc,
1128
+ data_attn_weight,
1129
+ batch_size,
1130
+ spatial_size,
1131
+ num_heads,
1132
+ channels,
1133
+ num_levels,
1134
+ num_query,
1135
+ num_point,
1136
+ grad_value,
1137
+ grad_sampling_loc,
1138
+ grad_attn_weight);
1139
+ break;
1140
+ case 32:
1141
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
1142
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1143
+ 0, stream>>>(
1144
+ num_kernels,
1145
+ grad_col,
1146
+ data_value,
1147
+ data_spatial_shapes,
1148
+ data_level_start_index,
1149
+ data_sampling_loc,
1150
+ data_attn_weight,
1151
+ batch_size,
1152
+ spatial_size,
1153
+ num_heads,
1154
+ channels,
1155
+ num_levels,
1156
+ num_query,
1157
+ num_point,
1158
+ grad_value,
1159
+ grad_sampling_loc,
1160
+ grad_attn_weight);
1161
+ break;
1162
+ case 64:
1163
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
1164
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1165
+ 0, stream>>>(
1166
+ num_kernels,
1167
+ grad_col,
1168
+ data_value,
1169
+ data_spatial_shapes,
1170
+ data_level_start_index,
1171
+ data_sampling_loc,
1172
+ data_attn_weight,
1173
+ batch_size,
1174
+ spatial_size,
1175
+ num_heads,
1176
+ channels,
1177
+ num_levels,
1178
+ num_query,
1179
+ num_point,
1180
+ grad_value,
1181
+ grad_sampling_loc,
1182
+ grad_attn_weight);
1183
+ break;
1184
+ case 128:
1185
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
1186
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1187
+ 0, stream>>>(
1188
+ num_kernels,
1189
+ grad_col,
1190
+ data_value,
1191
+ data_spatial_shapes,
1192
+ data_level_start_index,
1193
+ data_sampling_loc,
1194
+ data_attn_weight,
1195
+ batch_size,
1196
+ spatial_size,
1197
+ num_heads,
1198
+ channels,
1199
+ num_levels,
1200
+ num_query,
1201
+ num_point,
1202
+ grad_value,
1203
+ grad_sampling_loc,
1204
+ grad_attn_weight);
1205
+ break;
1206
+ case 256:
1207
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
1208
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1209
+ 0, stream>>>(
1210
+ num_kernels,
1211
+ grad_col,
1212
+ data_value,
1213
+ data_spatial_shapes,
1214
+ data_level_start_index,
1215
+ data_sampling_loc,
1216
+ data_attn_weight,
1217
+ batch_size,
1218
+ spatial_size,
1219
+ num_heads,
1220
+ channels,
1221
+ num_levels,
1222
+ num_query,
1223
+ num_point,
1224
+ grad_value,
1225
+ grad_sampling_loc,
1226
+ grad_attn_weight);
1227
+ break;
1228
+ case 512:
1229
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
1230
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1231
+ 0, stream>>>(
1232
+ num_kernels,
1233
+ grad_col,
1234
+ data_value,
1235
+ data_spatial_shapes,
1236
+ data_level_start_index,
1237
+ data_sampling_loc,
1238
+ data_attn_weight,
1239
+ batch_size,
1240
+ spatial_size,
1241
+ num_heads,
1242
+ channels,
1243
+ num_levels,
1244
+ num_query,
1245
+ num_point,
1246
+ grad_value,
1247
+ grad_sampling_loc,
1248
+ grad_attn_weight);
1249
+ break;
1250
+ case 1024:
1251
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
1252
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1253
+ 0, stream>>>(
1254
+ num_kernels,
1255
+ grad_col,
1256
+ data_value,
1257
+ data_spatial_shapes,
1258
+ data_level_start_index,
1259
+ data_sampling_loc,
1260
+ data_attn_weight,
1261
+ batch_size,
1262
+ spatial_size,
1263
+ num_heads,
1264
+ channels,
1265
+ num_levels,
1266
+ num_query,
1267
+ num_point,
1268
+ grad_value,
1269
+ grad_sampling_loc,
1270
+ grad_attn_weight);
1271
+ break;
1272
+ default:
1273
+ if (channels < 64)
1274
+ {
1275
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
1276
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1277
+ num_threads*3*sizeof(scalar_t), stream>>>(
1278
+ num_kernels,
1279
+ grad_col,
1280
+ data_value,
1281
+ data_spatial_shapes,
1282
+ data_level_start_index,
1283
+ data_sampling_loc,
1284
+ data_attn_weight,
1285
+ batch_size,
1286
+ spatial_size,
1287
+ num_heads,
1288
+ channels,
1289
+ num_levels,
1290
+ num_query,
1291
+ num_point,
1292
+ grad_value,
1293
+ grad_sampling_loc,
1294
+ grad_attn_weight);
1295
+ }
1296
+ else
1297
+ {
1298
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
1299
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1300
+ num_threads*3*sizeof(scalar_t), stream>>>(
1301
+ num_kernels,
1302
+ grad_col,
1303
+ data_value,
1304
+ data_spatial_shapes,
1305
+ data_level_start_index,
1306
+ data_sampling_loc,
1307
+ data_attn_weight,
1308
+ batch_size,
1309
+ spatial_size,
1310
+ num_heads,
1311
+ channels,
1312
+ num_levels,
1313
+ num_query,
1314
+ num_point,
1315
+ grad_value,
1316
+ grad_sampling_loc,
1317
+ grad_attn_weight);
1318
+ }
1319
+ }
1320
+ }
1321
+ cudaError_t err = cudaGetLastError();
1322
+ if (err != cudaSuccess)
1323
+ {
1324
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
1325
+ }
1326
+
1327
+ }
GroundingDINO/groundingdino/models/GroundingDINO/csrc/cuda_version.cu ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #include <cuda_runtime_api.h>
2
+
3
+ namespace groundingdino {
4
+ int get_cudart_version() {
5
+ return CUDART_VERSION;
6
+ }
7
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/csrc/vision.cpp ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ #include "MsDeformAttn/ms_deform_attn.h"
4
+
5
+ namespace groundingdino {
6
+
7
+ #ifdef WITH_CUDA
8
+ extern int get_cudart_version();
9
+ #endif
10
+
11
+ std::string get_cuda_version() {
12
+ #ifdef WITH_CUDA
13
+ std::ostringstream oss;
14
+
15
+ // copied from
16
+ // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
17
+ auto printCudaStyleVersion = [&](int v) {
18
+ oss << (v / 1000) << "." << (v / 10 % 100);
19
+ if (v % 10 != 0) {
20
+ oss << "." << (v % 10);
21
+ }
22
+ };
23
+ printCudaStyleVersion(get_cudart_version());
24
+ return oss.str();
25
+ #else
26
+ return std::string("not available");
27
+ #endif
28
+ }
29
+
30
+ // similar to
31
+ // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp
32
+ std::string get_compiler_version() {
33
+ std::ostringstream ss;
34
+ #if defined(__GNUC__)
35
+ #ifndef __clang__
36
+ { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; }
37
+ #endif
38
+ #endif
39
+
40
+ #if defined(__clang_major__)
41
+ {
42
+ ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
43
+ << __clang_patchlevel__;
44
+ }
45
+ #endif
46
+
47
+ #if defined(_MSC_VER)
48
+ { ss << "MSVC " << _MSC_FULL_VER; }
49
+ #endif
50
+ return ss.str();
51
+ }
52
+
53
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
54
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
55
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
56
+ }
57
+
58
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/fuse_modules.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from timm.models.layers import DropPath
12
+
13
+
14
+ class FeatureResizer(nn.Module):
15
+ """
16
+ This class takes as input a set of embeddings of dimension C1 and outputs a set of
17
+ embedding of dimension C2, after a linear transformation, dropout and normalization (LN).
18
+ """
19
+
20
+ def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
21
+ super().__init__()
22
+ self.do_ln = do_ln
23
+ # Object feature encoding
24
+ self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
25
+ self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
26
+ self.dropout = nn.Dropout(dropout)
27
+
28
+ def forward(self, encoder_features):
29
+ x = self.fc(encoder_features)
30
+ if self.do_ln:
31
+ x = self.layer_norm(x)
32
+ output = self.dropout(x)
33
+ return output
34
+
35
+
36
+ def l1norm(X, dim, eps=1e-8):
37
+ """L1-normalize columns of X"""
38
+ norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
39
+ X = torch.div(X, norm)
40
+ return X
41
+
42
+
43
+ def l2norm(X, dim, eps=1e-8):
44
+ """L2-normalize columns of X"""
45
+ norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
46
+ X = torch.div(X, norm)
47
+ return X
48
+
49
+
50
+ def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
51
+ """
52
+ query: (n_context, queryL, d)
53
+ context: (n_context, sourceL, d)
54
+ """
55
+ batch_size_q, queryL = query.size(0), query.size(1)
56
+ batch_size, sourceL = context.size(0), context.size(1)
57
+
58
+ # Get attention
59
+ # --> (batch, d, queryL)
60
+ queryT = torch.transpose(query, 1, 2)
61
+
62
+ # (batch, sourceL, d)(batch, d, queryL)
63
+ # --> (batch, sourceL, queryL)
64
+ attn = torch.bmm(context, queryT)
65
+ if raw_feature_norm == "softmax":
66
+ # --> (batch*sourceL, queryL)
67
+ attn = attn.view(batch_size * sourceL, queryL)
68
+ attn = nn.Softmax()(attn)
69
+ # --> (batch, sourceL, queryL)
70
+ attn = attn.view(batch_size, sourceL, queryL)
71
+ elif raw_feature_norm == "l2norm":
72
+ attn = l2norm(attn, 2)
73
+ elif raw_feature_norm == "clipped_l2norm":
74
+ attn = nn.LeakyReLU(0.1)(attn)
75
+ attn = l2norm(attn, 2)
76
+ else:
77
+ raise ValueError("unknown first norm type:", raw_feature_norm)
78
+ # --> (batch, queryL, sourceL)
79
+ attn = torch.transpose(attn, 1, 2).contiguous()
80
+ # --> (batch*queryL, sourceL)
81
+ attn = attn.view(batch_size * queryL, sourceL)
82
+ attn = nn.Softmax()(attn * smooth)
83
+ # --> (batch, queryL, sourceL)
84
+ attn = attn.view(batch_size, queryL, sourceL)
85
+ # --> (batch, sourceL, queryL)
86
+ attnT = torch.transpose(attn, 1, 2).contiguous()
87
+
88
+ # --> (batch, d, sourceL)
89
+ contextT = torch.transpose(context, 1, 2)
90
+ # (batch x d x sourceL)(batch x sourceL x queryL)
91
+ # --> (batch, d, queryL)
92
+ weightedContext = torch.bmm(contextT, attnT)
93
+ # --> (batch, queryL, d)
94
+ weightedContext = torch.transpose(weightedContext, 1, 2)
95
+
96
+ return weightedContext, attnT
97
+
98
+
99
+ class BiMultiHeadAttention(nn.Module):
100
+ def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
101
+ super(BiMultiHeadAttention, self).__init__()
102
+
103
+ self.embed_dim = embed_dim
104
+ self.num_heads = num_heads
105
+ self.head_dim = embed_dim // num_heads
106
+ self.v_dim = v_dim
107
+ self.l_dim = l_dim
108
+
109
+ assert (
110
+ self.head_dim * self.num_heads == self.embed_dim
111
+ ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
112
+ self.scale = self.head_dim ** (-0.5)
113
+ self.dropout = dropout
114
+
115
+ self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
116
+ self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
117
+ self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
118
+ self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
119
+
120
+ self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
121
+ self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
122
+
123
+ self.stable_softmax_2d = True
124
+ self.clamp_min_for_underflow = True
125
+ self.clamp_max_for_overflow = True
126
+
127
+ self._reset_parameters()
128
+
129
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
130
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
131
+
132
+ def _reset_parameters(self):
133
+ nn.init.xavier_uniform_(self.v_proj.weight)
134
+ self.v_proj.bias.data.fill_(0)
135
+ nn.init.xavier_uniform_(self.l_proj.weight)
136
+ self.l_proj.bias.data.fill_(0)
137
+ nn.init.xavier_uniform_(self.values_v_proj.weight)
138
+ self.values_v_proj.bias.data.fill_(0)
139
+ nn.init.xavier_uniform_(self.values_l_proj.weight)
140
+ self.values_l_proj.bias.data.fill_(0)
141
+ nn.init.xavier_uniform_(self.out_v_proj.weight)
142
+ self.out_v_proj.bias.data.fill_(0)
143
+ nn.init.xavier_uniform_(self.out_l_proj.weight)
144
+ self.out_l_proj.bias.data.fill_(0)
145
+
146
+ def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
147
+ """_summary_
148
+
149
+ Args:
150
+ v (_type_): bs, n_img, dim
151
+ l (_type_): bs, n_text, dim
152
+ attention_mask_v (_type_, optional): _description_. bs, n_img
153
+ attention_mask_l (_type_, optional): _description_. bs, n_text
154
+
155
+ Returns:
156
+ _type_: _description_
157
+ """
158
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
159
+ # import ipdb; ipdb.set_trace()
160
+ bsz, tgt_len, _ = v.size()
161
+
162
+ query_states = self.v_proj(v) * self.scale
163
+ key_states = self._shape(self.l_proj(l), -1, bsz)
164
+ value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
165
+ value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
166
+
167
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
168
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
169
+ key_states = key_states.view(*proj_shape)
170
+ value_v_states = value_v_states.view(*proj_shape)
171
+ value_l_states = value_l_states.view(*proj_shape)
172
+
173
+ src_len = key_states.size(1)
174
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt
175
+
176
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
177
+ raise ValueError(
178
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
179
+ )
180
+
181
+ if self.stable_softmax_2d:
182
+ attn_weights = attn_weights - attn_weights.max()
183
+
184
+ if self.clamp_min_for_underflow:
185
+ attn_weights = torch.clamp(
186
+ attn_weights, min=-50000
187
+ ) # Do not increase -50000, data type half has quite limited range
188
+ if self.clamp_max_for_overflow:
189
+ attn_weights = torch.clamp(
190
+ attn_weights, max=50000
191
+ ) # Do not increase 50000, data type half has quite limited range
192
+
193
+ attn_weights_T = attn_weights.transpose(1, 2)
194
+ attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
195
+ if self.clamp_min_for_underflow:
196
+ attn_weights_l = torch.clamp(
197
+ attn_weights_l, min=-50000
198
+ ) # Do not increase -50000, data type half has quite limited range
199
+ if self.clamp_max_for_overflow:
200
+ attn_weights_l = torch.clamp(
201
+ attn_weights_l, max=50000
202
+ ) # Do not increase 50000, data type half has quite limited range
203
+
204
+ # mask vison for language
205
+ if attention_mask_v is not None:
206
+ attention_mask_v = (
207
+ attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
208
+ )
209
+ attn_weights_l.masked_fill_(attention_mask_v, float("-inf"))
210
+
211
+ attn_weights_l = attn_weights_l.softmax(dim=-1)
212
+
213
+ # mask language for vision
214
+ if attention_mask_l is not None:
215
+ attention_mask_l = (
216
+ attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
217
+ )
218
+ attn_weights.masked_fill_(attention_mask_l, float("-inf"))
219
+ attn_weights_v = attn_weights.softmax(dim=-1)
220
+
221
+ attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
222
+ attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
223
+
224
+ attn_output_v = torch.bmm(attn_probs_v, value_l_states)
225
+ attn_output_l = torch.bmm(attn_probs_l, value_v_states)
226
+
227
+ if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
228
+ raise ValueError(
229
+ f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
230
+ )
231
+
232
+ if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
233
+ raise ValueError(
234
+ f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
235
+ )
236
+
237
+ attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
238
+ attn_output_v = attn_output_v.transpose(1, 2)
239
+ attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
240
+
241
+ attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
242
+ attn_output_l = attn_output_l.transpose(1, 2)
243
+ attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
244
+
245
+ attn_output_v = self.out_v_proj(attn_output_v)
246
+ attn_output_l = self.out_l_proj(attn_output_l)
247
+
248
+ return attn_output_v, attn_output_l
249
+
250
+
251
+ # Bi-Direction MHA (text->image, image->text)
252
+ class BiAttentionBlock(nn.Module):
253
+ def __init__(
254
+ self,
255
+ v_dim,
256
+ l_dim,
257
+ embed_dim,
258
+ num_heads,
259
+ dropout=0.1,
260
+ drop_path=0.0,
261
+ init_values=1e-4,
262
+ cfg=None,
263
+ ):
264
+ """
265
+ Inputs:
266
+ embed_dim - Dimensionality of input and attention feature vectors
267
+ hidden_dim - Dimensionality of hidden layer in feed-forward network
268
+ (usually 2-4x larger than embed_dim)
269
+ num_heads - Number of heads to use in the Multi-Head Attention block
270
+ dropout - Amount of dropout to apply in the feed-forward network
271
+ """
272
+ super(BiAttentionBlock, self).__init__()
273
+
274
+ # pre layer norm
275
+ self.layer_norm_v = nn.LayerNorm(v_dim)
276
+ self.layer_norm_l = nn.LayerNorm(l_dim)
277
+ self.attn = BiMultiHeadAttention(
278
+ v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout
279
+ )
280
+
281
+ # add layer scale for training stability
282
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
283
+ self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
284
+ self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
285
+
286
+ def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
287
+ v = self.layer_norm_v(v)
288
+ l = self.layer_norm_l(l)
289
+ delta_v, delta_l = self.attn(
290
+ v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l
291
+ )
292
+ # v, l = v + delta_v, l + delta_l
293
+ v = v + self.drop_path(self.gamma_v * delta_v)
294
+ l = l + self.drop_path(self.gamma_l * delta_l)
295
+ return v, l
296
+
297
+ # def forward(self, v:List[torch.Tensor], l, attention_mask_v=None, attention_mask_l=None)
GroundingDINO/groundingdino/models/GroundingDINO/groundingdino.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Conditional DETR model and criterion classes.
8
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Modified from DETR (https://github.com/facebookresearch/detr)
12
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13
+ # ------------------------------------------------------------------------
14
+ # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
15
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
16
+ # ------------------------------------------------------------------------
17
+ import copy
18
+ from typing import List
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from torch import nn
23
+ from torchvision.ops.boxes import nms
24
+ from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
25
+
26
+ from groundingdino.util import box_ops, get_tokenlizer
27
+ from groundingdino.util.misc import (
28
+ NestedTensor,
29
+ accuracy,
30
+ get_world_size,
31
+ interpolate,
32
+ inverse_sigmoid,
33
+ is_dist_avail_and_initialized,
34
+ nested_tensor_from_tensor_list,
35
+ )
36
+ from groundingdino.util.utils import get_phrases_from_posmap
37
+ from groundingdino.util.visualizer import COCOVisualizer
38
+ from groundingdino.util.vl_utils import create_positive_map_from_span
39
+
40
+ from ..registry import MODULE_BUILD_FUNCS
41
+ from .backbone import build_backbone
42
+ from .bertwarper import (
43
+ BertModelWarper,
44
+ generate_masks_with_special_tokens,
45
+ generate_masks_with_special_tokens_and_transfer_map,
46
+ )
47
+ from .transformer import build_transformer
48
+ from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss
49
+
50
+
51
+ class GroundingDINO(nn.Module):
52
+ """This is the Cross-Attention Detector module that performs object detection"""
53
+
54
+ def __init__(
55
+ self,
56
+ backbone,
57
+ transformer,
58
+ num_queries,
59
+ aux_loss=False,
60
+ iter_update=False,
61
+ query_dim=2,
62
+ num_feature_levels=1,
63
+ nheads=8,
64
+ # two stage
65
+ two_stage_type="no", # ['no', 'standard']
66
+ dec_pred_bbox_embed_share=True,
67
+ two_stage_class_embed_share=True,
68
+ two_stage_bbox_embed_share=True,
69
+ num_patterns=0,
70
+ dn_number=100,
71
+ dn_box_noise_scale=0.4,
72
+ dn_label_noise_ratio=0.5,
73
+ dn_labelbook_size=100,
74
+ text_encoder_type="bert-base-uncased",
75
+ sub_sentence_present=True,
76
+ max_text_len=256,
77
+ ):
78
+ """Initializes the model.
79
+ Parameters:
80
+ backbone: torch module of the backbone to be used. See backbone.py
81
+ transformer: torch module of the transformer architecture. See transformer.py
82
+ num_queries: number of object queries, ie detection slot. This is the maximal number of objects
83
+ Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
84
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
85
+ """
86
+ super().__init__()
87
+ self.num_queries = num_queries
88
+ self.transformer = transformer
89
+ self.hidden_dim = hidden_dim = transformer.d_model
90
+ self.num_feature_levels = num_feature_levels
91
+ self.nheads = nheads
92
+ self.max_text_len = 256
93
+ self.sub_sentence_present = sub_sentence_present
94
+
95
+ # setting query dim
96
+ self.query_dim = query_dim
97
+ assert query_dim == 4
98
+
99
+ # for dn training
100
+ self.num_patterns = num_patterns
101
+ self.dn_number = dn_number
102
+ self.dn_box_noise_scale = dn_box_noise_scale
103
+ self.dn_label_noise_ratio = dn_label_noise_ratio
104
+ self.dn_labelbook_size = dn_labelbook_size
105
+
106
+ # bert
107
+ self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
108
+ self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
109
+ self.bert.pooler.dense.weight.requires_grad_(False)
110
+ self.bert.pooler.dense.bias.requires_grad_(False)
111
+ self.bert = BertModelWarper(bert_model=self.bert)
112
+
113
+ self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
114
+ nn.init.constant_(self.feat_map.bias.data, 0)
115
+ nn.init.xavier_uniform_(self.feat_map.weight.data)
116
+ # freeze
117
+
118
+ # special tokens
119
+ self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
120
+
121
+ # prepare input projection layers
122
+ if num_feature_levels > 1:
123
+ num_backbone_outs = len(backbone.num_channels)
124
+ input_proj_list = []
125
+ for _ in range(num_backbone_outs):
126
+ in_channels = backbone.num_channels[_]
127
+ input_proj_list.append(
128
+ nn.Sequential(
129
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
130
+ nn.GroupNorm(32, hidden_dim),
131
+ )
132
+ )
133
+ for _ in range(num_feature_levels - num_backbone_outs):
134
+ input_proj_list.append(
135
+ nn.Sequential(
136
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
137
+ nn.GroupNorm(32, hidden_dim),
138
+ )
139
+ )
140
+ in_channels = hidden_dim
141
+ self.input_proj = nn.ModuleList(input_proj_list)
142
+ else:
143
+ assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!"
144
+ self.input_proj = nn.ModuleList(
145
+ [
146
+ nn.Sequential(
147
+ nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
148
+ nn.GroupNorm(32, hidden_dim),
149
+ )
150
+ ]
151
+ )
152
+
153
+ self.backbone = backbone
154
+ self.aux_loss = aux_loss
155
+ self.box_pred_damping = box_pred_damping = None
156
+
157
+ self.iter_update = iter_update
158
+ assert iter_update, "Why not iter_update?"
159
+
160
+ # prepare pred layers
161
+ self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
162
+ # prepare class & box embed
163
+ _class_embed = ContrastiveEmbed()
164
+
165
+ _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
166
+ nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
167
+ nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
168
+
169
+ if dec_pred_bbox_embed_share:
170
+ box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
171
+ else:
172
+ box_embed_layerlist = [
173
+ copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)
174
+ ]
175
+ class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
176
+ self.bbox_embed = nn.ModuleList(box_embed_layerlist)
177
+ self.class_embed = nn.ModuleList(class_embed_layerlist)
178
+ self.transformer.decoder.bbox_embed = self.bbox_embed
179
+ self.transformer.decoder.class_embed = self.class_embed
180
+
181
+ # two stage
182
+ self.two_stage_type = two_stage_type
183
+ assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
184
+ two_stage_type
185
+ )
186
+ if two_stage_type != "no":
187
+ if two_stage_bbox_embed_share:
188
+ assert dec_pred_bbox_embed_share
189
+ self.transformer.enc_out_bbox_embed = _bbox_embed
190
+ else:
191
+ self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
192
+
193
+ if two_stage_class_embed_share:
194
+ assert dec_pred_bbox_embed_share
195
+ self.transformer.enc_out_class_embed = _class_embed
196
+ else:
197
+ self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
198
+
199
+ self.refpoint_embed = None
200
+
201
+ self._reset_parameters()
202
+
203
+ def _reset_parameters(self):
204
+ # init input_proj
205
+ for proj in self.input_proj:
206
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
207
+ nn.init.constant_(proj[0].bias, 0)
208
+
209
+ def init_ref_points(self, use_num_queries):
210
+ self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
211
+
212
+ def forward(self, samples: NestedTensor, targets: List = None, **kw):
213
+ """The forward expects a NestedTensor, which consists of:
214
+ - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
215
+ - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
216
+
217
+ It returns a dict with the following elements:
218
+ - "pred_logits": the classification logits (including no-object) for all queries.
219
+ Shape= [batch_size x num_queries x num_classes]
220
+ - "pred_boxes": The normalized boxes coordinates for all queries, represented as
221
+ (center_x, center_y, width, height). These values are normalized in [0, 1],
222
+ relative to the size of each individual image (disregarding possible padding).
223
+ See PostProcess for information on how to retrieve the unnormalized bounding box.
224
+ - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
225
+ dictionnaries containing the two above keys for each decoder layer.
226
+ """
227
+ if targets is None:
228
+ captions = kw["captions"]
229
+ else:
230
+ captions = [t["caption"] for t in targets]
231
+ len(captions)
232
+
233
+ # encoder texts
234
+ tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
235
+ samples.device
236
+ )
237
+ (
238
+ text_self_attention_masks,
239
+ position_ids,
240
+ cate_to_token_mask_list,
241
+ ) = generate_masks_with_special_tokens_and_transfer_map(
242
+ tokenized, self.specical_tokens, self.tokenizer
243
+ )
244
+
245
+ if text_self_attention_masks.shape[1] > self.max_text_len:
246
+ text_self_attention_masks = text_self_attention_masks[
247
+ :, : self.max_text_len, : self.max_text_len
248
+ ]
249
+ position_ids = position_ids[:, : self.max_text_len]
250
+ tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
251
+ tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
252
+ tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
253
+
254
+ # extract text embeddings
255
+ if self.sub_sentence_present:
256
+ tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
257
+ tokenized_for_encoder["attention_mask"] = text_self_attention_masks
258
+ tokenized_for_encoder["position_ids"] = position_ids
259
+ else:
260
+ # import ipdb; ipdb.set_trace()
261
+ tokenized_for_encoder = tokenized
262
+
263
+ bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768
264
+
265
+ encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model
266
+ text_token_mask = tokenized.attention_mask.bool() # bs, 195
267
+ # text_token_mask: True for nomask, False for mask
268
+ # text_self_attention_masks: True for nomask, False for mask
269
+
270
+ if encoded_text.shape[1] > self.max_text_len:
271
+ encoded_text = encoded_text[:, : self.max_text_len, :]
272
+ text_token_mask = text_token_mask[:, : self.max_text_len]
273
+ position_ids = position_ids[:, : self.max_text_len]
274
+ text_self_attention_masks = text_self_attention_masks[
275
+ :, : self.max_text_len, : self.max_text_len
276
+ ]
277
+
278
+ text_dict = {
279
+ "encoded_text": encoded_text, # bs, 195, d_model
280
+ "text_token_mask": text_token_mask, # bs, 195
281
+ "position_ids": position_ids, # bs, 195
282
+ "text_self_attention_masks": text_self_attention_masks, # bs, 195,195
283
+ }
284
+
285
+ # import ipdb; ipdb.set_trace()
286
+
287
+ if isinstance(samples, (list, torch.Tensor)):
288
+ samples = nested_tensor_from_tensor_list(samples)
289
+ features, poss = self.backbone(samples)
290
+
291
+ srcs = []
292
+ masks = []
293
+ for l, feat in enumerate(features):
294
+ src, mask = feat.decompose()
295
+ srcs.append(self.input_proj[l](src))
296
+ masks.append(mask)
297
+ assert mask is not None
298
+ if self.num_feature_levels > len(srcs):
299
+ _len_srcs = len(srcs)
300
+ for l in range(_len_srcs, self.num_feature_levels):
301
+ if l == _len_srcs:
302
+ src = self.input_proj[l](features[-1].tensors)
303
+ else:
304
+ src = self.input_proj[l](srcs[-1])
305
+ m = samples.mask
306
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
307
+ pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
308
+ srcs.append(src)
309
+ masks.append(mask)
310
+ poss.append(pos_l)
311
+
312
+ input_query_bbox = input_query_label = attn_mask = dn_meta = None
313
+ hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
314
+ srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict
315
+ )
316
+
317
+ # deformable-detr-like anchor update
318
+ outputs_coord_list = []
319
+ for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
320
+ zip(reference[:-1], self.bbox_embed, hs)
321
+ ):
322
+ layer_delta_unsig = layer_bbox_embed(layer_hs)
323
+ layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
324
+ layer_outputs_unsig = layer_outputs_unsig.sigmoid()
325
+ outputs_coord_list.append(layer_outputs_unsig)
326
+ outputs_coord_list = torch.stack(outputs_coord_list)
327
+
328
+ # output
329
+ outputs_class = torch.stack(
330
+ [
331
+ layer_cls_embed(layer_hs, text_dict)
332
+ for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
333
+ ]
334
+ )
335
+ out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}
336
+
337
+ # # for intermediate outputs
338
+ # if self.aux_loss:
339
+ # out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
340
+
341
+ # # for encoder output
342
+ # if hs_enc is not None:
343
+ # # prepare intermediate outputs
344
+ # interm_coord = ref_enc[-1]
345
+ # interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
346
+ # out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
347
+ # out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
348
+
349
+ return out
350
+
351
+ @torch.jit.unused
352
+ def _set_aux_loss(self, outputs_class, outputs_coord):
353
+ # this is a workaround to make torchscript happy, as torchscript
354
+ # doesn't support dictionary with non-homogeneous values, such
355
+ # as a dict having both a Tensor and a list.
356
+ return [
357
+ {"pred_logits": a, "pred_boxes": b}
358
+ for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
359
+ ]
360
+
361
+
362
+ @MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
363
+ def build_groundingdino(args):
364
+
365
+ backbone = build_backbone(args)
366
+ transformer = build_transformer(args)
367
+
368
+ dn_labelbook_size = args.dn_labelbook_size
369
+ dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
370
+ sub_sentence_present = args.sub_sentence_present
371
+
372
+ model = GroundingDINO(
373
+ backbone,
374
+ transformer,
375
+ num_queries=args.num_queries,
376
+ aux_loss=True,
377
+ iter_update=True,
378
+ query_dim=4,
379
+ num_feature_levels=args.num_feature_levels,
380
+ nheads=args.nheads,
381
+ dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
382
+ two_stage_type=args.two_stage_type,
383
+ two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
384
+ two_stage_class_embed_share=args.two_stage_class_embed_share,
385
+ num_patterns=args.num_patterns,
386
+ dn_number=0,
387
+ dn_box_noise_scale=args.dn_box_noise_scale,
388
+ dn_label_noise_ratio=args.dn_label_noise_ratio,
389
+ dn_labelbook_size=dn_labelbook_size,
390
+ text_encoder_type=args.text_encoder_type,
391
+ sub_sentence_present=sub_sentence_present,
392
+ max_text_len=args.max_text_len,
393
+ )
394
+
395
+ return model
GroundingDINO/groundingdino/models/GroundingDINO/ms_deform_attn.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Deformable DETR
8
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------------------------------
11
+ # Modified from:
12
+ # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py
13
+ # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
14
+ # https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/multi_scale_deform_attn.py
15
+ # ------------------------------------------------------------------------------------------------
16
+
17
+ import math
18
+ import warnings
19
+ from typing import Optional
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from torch.autograd import Function
25
+ from torch.autograd.function import once_differentiable
26
+ from torch.nn.init import constant_, xavier_uniform_
27
+
28
+ try:
29
+ from groundingdino import _C
30
+ except:
31
+ warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!")
32
+
33
+
34
+ # helpers
35
+ def _is_power_of_2(n):
36
+ if (not isinstance(n, int)) or (n < 0):
37
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
38
+ return (n & (n - 1) == 0) and n != 0
39
+
40
+
41
+ class MultiScaleDeformableAttnFunction(Function):
42
+ @staticmethod
43
+ def forward(
44
+ ctx,
45
+ value,
46
+ value_spatial_shapes,
47
+ value_level_start_index,
48
+ sampling_locations,
49
+ attention_weights,
50
+ im2col_step,
51
+ ):
52
+ ctx.im2col_step = im2col_step
53
+ output = _C.ms_deform_attn_forward(
54
+ value,
55
+ value_spatial_shapes,
56
+ value_level_start_index,
57
+ sampling_locations,
58
+ attention_weights,
59
+ ctx.im2col_step,
60
+ )
61
+ ctx.save_for_backward(
62
+ value,
63
+ value_spatial_shapes,
64
+ value_level_start_index,
65
+ sampling_locations,
66
+ attention_weights,
67
+ )
68
+ return output
69
+
70
+ @staticmethod
71
+ @once_differentiable
72
+ def backward(ctx, grad_output):
73
+ (
74
+ value,
75
+ value_spatial_shapes,
76
+ value_level_start_index,
77
+ sampling_locations,
78
+ attention_weights,
79
+ ) = ctx.saved_tensors
80
+ grad_value, grad_sampling_loc, grad_attn_weight = _C.ms_deform_attn_backward(
81
+ value,
82
+ value_spatial_shapes,
83
+ value_level_start_index,
84
+ sampling_locations,
85
+ attention_weights,
86
+ grad_output,
87
+ ctx.im2col_step,
88
+ )
89
+
90
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
91
+
92
+
93
+ def multi_scale_deformable_attn_pytorch(
94
+ value: torch.Tensor,
95
+ value_spatial_shapes: torch.Tensor,
96
+ sampling_locations: torch.Tensor,
97
+ attention_weights: torch.Tensor,
98
+ ) -> torch.Tensor:
99
+
100
+ bs, _, num_heads, embed_dims = value.shape
101
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
102
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
103
+ sampling_grids = 2 * sampling_locations - 1
104
+ sampling_value_list = []
105
+ for level, (H_, W_) in enumerate(value_spatial_shapes):
106
+ # bs, H_*W_, num_heads, embed_dims ->
107
+ # bs, H_*W_, num_heads*embed_dims ->
108
+ # bs, num_heads*embed_dims, H_*W_ ->
109
+ # bs*num_heads, embed_dims, H_, W_
110
+ value_l_ = (
111
+ value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
112
+ )
113
+ # bs, num_queries, num_heads, num_points, 2 ->
114
+ # bs, num_heads, num_queries, num_points, 2 ->
115
+ # bs*num_heads, num_queries, num_points, 2
116
+ sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
117
+ # bs*num_heads, embed_dims, num_queries, num_points
118
+ sampling_value_l_ = F.grid_sample(
119
+ value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
120
+ )
121
+ sampling_value_list.append(sampling_value_l_)
122
+ # (bs, num_queries, num_heads, num_levels, num_points) ->
123
+ # (bs, num_heads, num_queries, num_levels, num_points) ->
124
+ # (bs, num_heads, 1, num_queries, num_levels*num_points)
125
+ attention_weights = attention_weights.transpose(1, 2).reshape(
126
+ bs * num_heads, 1, num_queries, num_levels * num_points
127
+ )
128
+ output = (
129
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
130
+ .sum(-1)
131
+ .view(bs, num_heads * embed_dims, num_queries)
132
+ )
133
+ return output.transpose(1, 2).contiguous()
134
+
135
+
136
+ class MultiScaleDeformableAttention(nn.Module):
137
+ """Multi-Scale Deformable Attention Module used in Deformable-DETR
138
+
139
+ `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
140
+ <https://arxiv.org/pdf/2010.04159.pdf>`_.
141
+
142
+ Args:
143
+ embed_dim (int): The embedding dimension of Attention. Default: 256.
144
+ num_heads (int): The number of attention heads. Default: 8.
145
+ num_levels (int): The number of feature map used in Attention. Default: 4.
146
+ num_points (int): The number of sampling points for each query
147
+ in each head. Default: 4.
148
+ img2col_steps (int): The step used in image_to_column. Defualt: 64.
149
+ dropout (float): Dropout layer used in output. Default: 0.1.
150
+ batch_first (bool): if ``True``, then the input and output tensor will be
151
+ provided as `(bs, n, embed_dim)`. Default: False. `(n, bs, embed_dim)`
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ embed_dim: int = 256,
157
+ num_heads: int = 8,
158
+ num_levels: int = 4,
159
+ num_points: int = 4,
160
+ img2col_step: int = 64,
161
+ batch_first: bool = False,
162
+ ):
163
+ super().__init__()
164
+ if embed_dim % num_heads != 0:
165
+ raise ValueError(
166
+ "embed_dim must be divisible by num_heads, but got {} and {}".format(
167
+ embed_dim, num_heads
168
+ )
169
+ )
170
+ head_dim = embed_dim // num_heads
171
+
172
+ self.batch_first = batch_first
173
+
174
+ if not _is_power_of_2(head_dim):
175
+ warnings.warn(
176
+ """
177
+ You'd better set d_model in MSDeformAttn to make sure that
178
+ each dim of the attention head a power of 2, which is more efficient.
179
+ """
180
+ )
181
+
182
+ self.im2col_step = img2col_step
183
+ self.embed_dim = embed_dim
184
+ self.num_heads = num_heads
185
+ self.num_levels = num_levels
186
+ self.num_points = num_points
187
+ self.sampling_offsets = nn.Linear(embed_dim, num_heads * num_levels * num_points * 2)
188
+ self.attention_weights = nn.Linear(embed_dim, num_heads * num_levels * num_points)
189
+ self.value_proj = nn.Linear(embed_dim, embed_dim)
190
+ self.output_proj = nn.Linear(embed_dim, embed_dim)
191
+
192
+ self.init_weights()
193
+
194
+ def _reset_parameters(self):
195
+ return self.init_weights()
196
+
197
+ def init_weights(self):
198
+ """
199
+ Default initialization for Parameters of Module.
200
+ """
201
+ constant_(self.sampling_offsets.weight.data, 0.0)
202
+ thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
203
+ 2.0 * math.pi / self.num_heads
204
+ )
205
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
206
+ grid_init = (
207
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
208
+ .view(self.num_heads, 1, 1, 2)
209
+ .repeat(1, self.num_levels, self.num_points, 1)
210
+ )
211
+ for i in range(self.num_points):
212
+ grid_init[:, :, i, :] *= i + 1
213
+ with torch.no_grad():
214
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
215
+ constant_(self.attention_weights.weight.data, 0.0)
216
+ constant_(self.attention_weights.bias.data, 0.0)
217
+ xavier_uniform_(self.value_proj.weight.data)
218
+ constant_(self.value_proj.bias.data, 0.0)
219
+ xavier_uniform_(self.output_proj.weight.data)
220
+ constant_(self.output_proj.bias.data, 0.0)
221
+
222
+ def freeze_sampling_offsets(self):
223
+ print("Freeze sampling offsets")
224
+ self.sampling_offsets.weight.requires_grad = False
225
+ self.sampling_offsets.bias.requires_grad = False
226
+
227
+ def freeze_attention_weights(self):
228
+ print("Freeze attention weights")
229
+ self.attention_weights.weight.requires_grad = False
230
+ self.attention_weights.bias.requires_grad = False
231
+
232
+ def forward(
233
+ self,
234
+ query: torch.Tensor,
235
+ key: Optional[torch.Tensor] = None,
236
+ value: Optional[torch.Tensor] = None,
237
+ query_pos: Optional[torch.Tensor] = None,
238
+ key_padding_mask: Optional[torch.Tensor] = None,
239
+ reference_points: Optional[torch.Tensor] = None,
240
+ spatial_shapes: Optional[torch.Tensor] = None,
241
+ level_start_index: Optional[torch.Tensor] = None,
242
+ **kwargs
243
+ ) -> torch.Tensor:
244
+
245
+ """Forward Function of MultiScaleDeformableAttention
246
+
247
+ Args:
248
+ query (torch.Tensor): Query embeddings with shape
249
+ `(num_query, bs, embed_dim)`
250
+ key (torch.Tensor): Key embeddings with shape
251
+ `(num_key, bs, embed_dim)`
252
+ value (torch.Tensor): Value embeddings with shape
253
+ `(num_key, bs, embed_dim)`
254
+ query_pos (torch.Tensor): The position embedding for `query`. Default: None.
255
+ key_padding_mask (torch.Tensor): ByteTensor for `query`, with shape `(bs, num_key)`,
256
+ indicating which elements within `key` to be ignored in attention.
257
+ reference_points (torch.Tensor): The normalized reference points
258
+ with shape `(bs, num_query, num_levels, 2)`,
259
+ all elements is range in [0, 1], top-left (0, 0),
260
+ bottom-right (1, 1), including padding are.
261
+ or `(N, Length_{query}, num_levels, 4)`, add additional
262
+ two dimensions `(h, w)` to form reference boxes.
263
+ spatial_shapes (torch.Tensor): Spatial shape of features in different levels.
264
+ With shape `(num_levels, 2)`, last dimension represents `(h, w)`.
265
+ level_start_index (torch.Tensor): The start index of each level. A tensor with
266
+ shape `(num_levels, )` which can be represented as
267
+ `[0, h_0 * w_0, h_0 * w_0 + h_1 * w_1, ...]`.
268
+
269
+ Returns:
270
+ torch.Tensor: forward results with shape `(num_query, bs, embed_dim)`
271
+ """
272
+
273
+ if value is None:
274
+ value = query
275
+
276
+ if query_pos is not None:
277
+ query = query + query_pos
278
+
279
+ if not self.batch_first:
280
+ # change to (bs, num_query ,embed_dims)
281
+ query = query.permute(1, 0, 2)
282
+ value = value.permute(1, 0, 2)
283
+
284
+ bs, num_query, _ = query.shape
285
+ bs, num_value, _ = value.shape
286
+
287
+ assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
288
+
289
+ value = self.value_proj(value)
290
+ if key_padding_mask is not None:
291
+ value = value.masked_fill(key_padding_mask[..., None], float(0))
292
+ value = value.view(bs, num_value, self.num_heads, -1)
293
+ sampling_offsets = self.sampling_offsets(query).view(
294
+ bs, num_query, self.num_heads, self.num_levels, self.num_points, 2
295
+ )
296
+ attention_weights = self.attention_weights(query).view(
297
+ bs, num_query, self.num_heads, self.num_levels * self.num_points
298
+ )
299
+ attention_weights = attention_weights.softmax(-1)
300
+ attention_weights = attention_weights.view(
301
+ bs,
302
+ num_query,
303
+ self.num_heads,
304
+ self.num_levels,
305
+ self.num_points,
306
+ )
307
+
308
+ # bs, num_query, num_heads, num_levels, num_points, 2
309
+ if reference_points.shape[-1] == 2:
310
+ offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
311
+ sampling_locations = (
312
+ reference_points[:, :, None, :, None, :]
313
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
314
+ )
315
+ elif reference_points.shape[-1] == 4:
316
+ sampling_locations = (
317
+ reference_points[:, :, None, :, None, :2]
318
+ + sampling_offsets
319
+ / self.num_points
320
+ * reference_points[:, :, None, :, None, 2:]
321
+ * 0.5
322
+ )
323
+ else:
324
+ raise ValueError(
325
+ "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
326
+ reference_points.shape[-1]
327
+ )
328
+ )
329
+
330
+ if torch.cuda.is_available() and value.is_cuda:
331
+ halffloat = False
332
+ if value.dtype == torch.float16:
333
+ halffloat = True
334
+ value = value.float()
335
+ sampling_locations = sampling_locations.float()
336
+ attention_weights = attention_weights.float()
337
+
338
+ output = MultiScaleDeformableAttnFunction.apply(
339
+ value,
340
+ spatial_shapes,
341
+ level_start_index,
342
+ sampling_locations,
343
+ attention_weights,
344
+ self.im2col_step,
345
+ )
346
+
347
+ if halffloat:
348
+ output = output.half()
349
+ else:
350
+ output = multi_scale_deformable_attn_pytorch(
351
+ value, spatial_shapes, sampling_locations, attention_weights
352
+ )
353
+
354
+ output = self.output_proj(output)
355
+
356
+ if not self.batch_first:
357
+ output = output.permute(1, 0, 2)
358
+
359
+ return output
360
+
361
+
362
+ def create_dummy_class(klass, dependency, message=""):
363
+ """
364
+ When a dependency of a class is not available, create a dummy class which throws ImportError
365
+ when used.
366
+
367
+ Args:
368
+ klass (str): name of the class.
369
+ dependency (str): name of the dependency.
370
+ message: extra message to print
371
+ Returns:
372
+ class: a class object
373
+ """
374
+ err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass)
375
+ if message:
376
+ err = err + " " + message
377
+
378
+ class _DummyMetaClass(type):
379
+ # throw error on class attribute access
380
+ def __getattr__(_, __): # noqa: B902
381
+ raise ImportError(err)
382
+
383
+ class _Dummy(object, metaclass=_DummyMetaClass):
384
+ # throw error on constructor
385
+ def __init__(self, *args, **kwargs):
386
+ raise ImportError(err)
387
+
388
+ return _Dummy
389
+
390
+
391
+ def create_dummy_func(func, dependency, message=""):
392
+ """
393
+ When a dependency of a function is not available, create a dummy function which throws
394
+ ImportError when used.
395
+
396
+ Args:
397
+ func (str): name of the function.
398
+ dependency (str or list[str]): name(s) of the dependency.
399
+ message: extra message to print
400
+ Returns:
401
+ function: a function object
402
+ """
403
+ err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func)
404
+ if message:
405
+ err = err + " " + message
406
+
407
+ if isinstance(dependency, (list, tuple)):
408
+ dependency = ",".join(dependency)
409
+
410
+ def _dummy(*args, **kwargs):
411
+ raise ImportError(err)
412
+
413
+ return _dummy