unknown commited on
Commit
24a2388
1 Parent(s): 83529fe
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. GroundingDINO/LICENSE +201 -0
  2. GroundingDINO/groundingdino/__init__.py +0 -0
  3. GroundingDINO/groundingdino/config/GroundingDINO_SwinB.py +43 -0
  4. GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py +43 -0
  5. GroundingDINO/groundingdino/datasets/__init__.py +0 -0
  6. GroundingDINO/groundingdino/datasets/transforms.py +311 -0
  7. GroundingDINO/groundingdino/models/GroundingDINO/__init__.py +15 -0
  8. GroundingDINO/groundingdino/models/GroundingDINO/backbone/__init__.py +1 -0
  9. GroundingDINO/groundingdino/models/GroundingDINO/backbone/backbone.py +221 -0
  10. GroundingDINO/groundingdino/models/GroundingDINO/backbone/position_encoding.py +186 -0
  11. {models → GroundingDINO/groundingdino/models/GroundingDINO/backbone}/swin_transformer.py +488 -340
  12. GroundingDINO/groundingdino/models/GroundingDINO/bertwarper.py +273 -0
  13. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h +64 -0
  14. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp +43 -0
  15. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h +35 -0
  16. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu +156 -0
  17. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h +33 -0
  18. GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh +1327 -0
  19. GroundingDINO/groundingdino/models/GroundingDINO/csrc/cuda_version.cu +7 -0
  20. GroundingDINO/groundingdino/models/GroundingDINO/csrc/vision.cpp +58 -0
  21. GroundingDINO/groundingdino/models/GroundingDINO/fuse_modules.py +297 -0
  22. GroundingDINO/groundingdino/models/GroundingDINO/groundingdino.py +395 -0
  23. GroundingDINO/groundingdino/models/GroundingDINO/ms_deform_attn.py +413 -0
  24. GroundingDINO/groundingdino/models/GroundingDINO/transformer.py +959 -0
  25. GroundingDINO/groundingdino/models/GroundingDINO/transformer_vanilla.py +123 -0
  26. GroundingDINO/groundingdino/models/GroundingDINO/utils.py +268 -0
  27. GroundingDINO/groundingdino/models/__init__.py +18 -0
  28. GroundingDINO/groundingdino/models/registry.py +66 -0
  29. GroundingDINO/groundingdino/util/__init__.py +1 -0
  30. GroundingDINO/groundingdino/util/box_ops.py +140 -0
  31. GroundingDINO/groundingdino/util/get_tokenlizer.py +26 -0
  32. GroundingDINO/groundingdino/util/inference.py +257 -0
  33. GroundingDINO/groundingdino/util/logger.py +93 -0
  34. GroundingDINO/groundingdino/util/misc.py +717 -0
  35. GroundingDINO/groundingdino/util/slconfig.py +427 -0
  36. GroundingDINO/groundingdino/util/slio.py +177 -0
  37. GroundingDINO/groundingdino/util/time_counter.py +62 -0
  38. GroundingDINO/groundingdino/util/utils.py +608 -0
  39. GroundingDINO/groundingdino/util/visualizer.py +318 -0
  40. GroundingDINO/groundingdino/util/vl_utils.py +100 -0
  41. GroundingDINO/groundingdino/version.py +1 -0
  42. GroundingDINO/setup.py +216 -0
  43. README.md +11 -178
  44. app.py +371 -193
  45. configs/med_config.json +0 -21
  46. configs/q2l_config.json +0 -22
  47. configs/swin/config_swinB_224.json +0 -10
  48. configs/swin/config_swinB_384.json +0 -10
  49. configs/swin/config_swinB_480.json +0 -9
  50. configs/swin/config_swinB_576.json +0 -9
GroundingDINO/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2020 - present, Facebook, Inc
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
GroundingDINO/groundingdino/__init__.py ADDED
File without changes
GroundingDINO/groundingdino/config/GroundingDINO_SwinB.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size = 1
2
+ modelname = "groundingdino"
3
+ backbone = "swin_B_384_22k"
4
+ position_embedding = "sine"
5
+ pe_temperatureH = 20
6
+ pe_temperatureW = 20
7
+ return_interm_indices = [1, 2, 3]
8
+ backbone_freeze_keywords = None
9
+ enc_layers = 6
10
+ dec_layers = 6
11
+ pre_norm = False
12
+ dim_feedforward = 2048
13
+ hidden_dim = 256
14
+ dropout = 0.0
15
+ nheads = 8
16
+ num_queries = 900
17
+ query_dim = 4
18
+ num_patterns = 0
19
+ num_feature_levels = 4
20
+ enc_n_points = 4
21
+ dec_n_points = 4
22
+ two_stage_type = "standard"
23
+ two_stage_bbox_embed_share = False
24
+ two_stage_class_embed_share = False
25
+ transformer_activation = "relu"
26
+ dec_pred_bbox_embed_share = True
27
+ dn_box_noise_scale = 1.0
28
+ dn_label_noise_ratio = 0.5
29
+ dn_label_coef = 1.0
30
+ dn_bbox_coef = 1.0
31
+ embed_init_tgt = True
32
+ dn_labelbook_size = 2000
33
+ max_text_len = 256
34
+ text_encoder_type = "bert-base-uncased"
35
+ use_text_enhancer = True
36
+ use_fusion_layer = True
37
+ use_checkpoint = True
38
+ use_transformer_ckpt = True
39
+ use_text_cross_attention = True
40
+ text_dropout = 0.0
41
+ fusion_dropout = 0.0
42
+ fusion_droppath = 0.1
43
+ sub_sentence_present = True
GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size = 1
2
+ modelname = "groundingdino"
3
+ backbone = "swin_T_224_1k"
4
+ position_embedding = "sine"
5
+ pe_temperatureH = 20
6
+ pe_temperatureW = 20
7
+ return_interm_indices = [1, 2, 3]
8
+ backbone_freeze_keywords = None
9
+ enc_layers = 6
10
+ dec_layers = 6
11
+ pre_norm = False
12
+ dim_feedforward = 2048
13
+ hidden_dim = 256
14
+ dropout = 0.0
15
+ nheads = 8
16
+ num_queries = 900
17
+ query_dim = 4
18
+ num_patterns = 0
19
+ num_feature_levels = 4
20
+ enc_n_points = 4
21
+ dec_n_points = 4
22
+ two_stage_type = "standard"
23
+ two_stage_bbox_embed_share = False
24
+ two_stage_class_embed_share = False
25
+ transformer_activation = "relu"
26
+ dec_pred_bbox_embed_share = True
27
+ dn_box_noise_scale = 1.0
28
+ dn_label_noise_ratio = 0.5
29
+ dn_label_coef = 1.0
30
+ dn_bbox_coef = 1.0
31
+ embed_init_tgt = True
32
+ dn_labelbook_size = 2000
33
+ max_text_len = 256
34
+ text_encoder_type = "bert-base-uncased"
35
+ use_text_enhancer = True
36
+ use_fusion_layer = True
37
+ use_checkpoint = True
38
+ use_transformer_ckpt = True
39
+ use_text_cross_attention = True
40
+ text_dropout = 0.0
41
+ fusion_dropout = 0.0
42
+ fusion_droppath = 0.1
43
+ sub_sentence_present = True
GroundingDINO/groundingdino/datasets/__init__.py ADDED
File without changes
GroundingDINO/groundingdino/datasets/transforms.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Transforms and data augmentation for both image + bbox.
4
+ """
5
+ import os
6
+ import random
7
+
8
+ import PIL
9
+ import torch
10
+ import torchvision.transforms as T
11
+ import torchvision.transforms.functional as F
12
+
13
+ from groundingdino.util.box_ops import box_xyxy_to_cxcywh
14
+ from groundingdino.util.misc import interpolate
15
+
16
+
17
+ def crop(image, target, region):
18
+ cropped_image = F.crop(image, *region)
19
+
20
+ target = target.copy()
21
+ i, j, h, w = region
22
+
23
+ # should we do something wrt the original size?
24
+ target["size"] = torch.tensor([h, w])
25
+
26
+ fields = ["labels", "area", "iscrowd", "positive_map"]
27
+
28
+ if "boxes" in target:
29
+ boxes = target["boxes"]
30
+ max_size = torch.as_tensor([w, h], dtype=torch.float32)
31
+ cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
32
+ cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
33
+ cropped_boxes = cropped_boxes.clamp(min=0)
34
+ area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
35
+ target["boxes"] = cropped_boxes.reshape(-1, 4)
36
+ target["area"] = area
37
+ fields.append("boxes")
38
+
39
+ if "masks" in target:
40
+ # FIXME should we update the area here if there are no boxes?
41
+ target["masks"] = target["masks"][:, i : i + h, j : j + w]
42
+ fields.append("masks")
43
+
44
+ # remove elements for which the boxes or masks that have zero area
45
+ if "boxes" in target or "masks" in target:
46
+ # favor boxes selection when defining which elements to keep
47
+ # this is compatible with previous implementation
48
+ if "boxes" in target:
49
+ cropped_boxes = target["boxes"].reshape(-1, 2, 2)
50
+ keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
51
+ else:
52
+ keep = target["masks"].flatten(1).any(1)
53
+
54
+ for field in fields:
55
+ if field in target:
56
+ target[field] = target[field][keep]
57
+
58
+ if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
59
+ # for debug and visualization only.
60
+ if "strings_positive" in target:
61
+ target["strings_positive"] = [
62
+ _i for _i, _j in zip(target["strings_positive"], keep) if _j
63
+ ]
64
+
65
+ return cropped_image, target
66
+
67
+
68
+ def hflip(image, target):
69
+ flipped_image = F.hflip(image)
70
+
71
+ w, h = image.size
72
+
73
+ target = target.copy()
74
+ if "boxes" in target:
75
+ boxes = target["boxes"]
76
+ boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
77
+ [w, 0, w, 0]
78
+ )
79
+ target["boxes"] = boxes
80
+
81
+ if "masks" in target:
82
+ target["masks"] = target["masks"].flip(-1)
83
+
84
+ return flipped_image, target
85
+
86
+
87
+ def resize(image, target, size, max_size=None):
88
+ # size can be min_size (scalar) or (w, h) tuple
89
+
90
+ def get_size_with_aspect_ratio(image_size, size, max_size=None):
91
+ w, h = image_size
92
+ if max_size is not None:
93
+ min_original_size = float(min((w, h)))
94
+ max_original_size = float(max((w, h)))
95
+ if max_original_size / min_original_size * size > max_size:
96
+ size = int(round(max_size * min_original_size / max_original_size))
97
+
98
+ if (w <= h and w == size) or (h <= w and h == size):
99
+ return (h, w)
100
+
101
+ if w < h:
102
+ ow = size
103
+ oh = int(size * h / w)
104
+ else:
105
+ oh = size
106
+ ow = int(size * w / h)
107
+
108
+ return (oh, ow)
109
+
110
+ def get_size(image_size, size, max_size=None):
111
+ if isinstance(size, (list, tuple)):
112
+ return size[::-1]
113
+ else:
114
+ return get_size_with_aspect_ratio(image_size, size, max_size)
115
+
116
+ size = get_size(image.size, size, max_size)
117
+ rescaled_image = F.resize(image, size)
118
+
119
+ if target is None:
120
+ return rescaled_image, None
121
+
122
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
123
+ ratio_width, ratio_height = ratios
124
+
125
+ target = target.copy()
126
+ if "boxes" in target:
127
+ boxes = target["boxes"]
128
+ scaled_boxes = boxes * torch.as_tensor(
129
+ [ratio_width, ratio_height, ratio_width, ratio_height]
130
+ )
131
+ target["boxes"] = scaled_boxes
132
+
133
+ if "area" in target:
134
+ area = target["area"]
135
+ scaled_area = area * (ratio_width * ratio_height)
136
+ target["area"] = scaled_area
137
+
138
+ h, w = size
139
+ target["size"] = torch.tensor([h, w])
140
+
141
+ if "masks" in target:
142
+ target["masks"] = (
143
+ interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
144
+ )
145
+
146
+ return rescaled_image, target
147
+
148
+
149
+ def pad(image, target, padding):
150
+ # assumes that we only pad on the bottom right corners
151
+ padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
152
+ if target is None:
153
+ return padded_image, None
154
+ target = target.copy()
155
+ # should we do something wrt the original size?
156
+ target["size"] = torch.tensor(padded_image.size[::-1])
157
+ if "masks" in target:
158
+ target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1]))
159
+ return padded_image, target
160
+
161
+
162
+ class ResizeDebug(object):
163
+ def __init__(self, size):
164
+ self.size = size
165
+
166
+ def __call__(self, img, target):
167
+ return resize(img, target, self.size)
168
+
169
+
170
+ class RandomCrop(object):
171
+ def __init__(self, size):
172
+ self.size = size
173
+
174
+ def __call__(self, img, target):
175
+ region = T.RandomCrop.get_params(img, self.size)
176
+ return crop(img, target, region)
177
+
178
+
179
+ class RandomSizeCrop(object):
180
+ def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
181
+ # respect_boxes: True to keep all boxes
182
+ # False to tolerence box filter
183
+ self.min_size = min_size
184
+ self.max_size = max_size
185
+ self.respect_boxes = respect_boxes
186
+
187
+ def __call__(self, img: PIL.Image.Image, target: dict):
188
+ init_boxes = len(target["boxes"])
189
+ max_patience = 10
190
+ for i in range(max_patience):
191
+ w = random.randint(self.min_size, min(img.width, self.max_size))
192
+ h = random.randint(self.min_size, min(img.height, self.max_size))
193
+ region = T.RandomCrop.get_params(img, [h, w])
194
+ result_img, result_target = crop(img, target, region)
195
+ if (
196
+ not self.respect_boxes
197
+ or len(result_target["boxes"]) == init_boxes
198
+ or i == max_patience - 1
199
+ ):
200
+ return result_img, result_target
201
+ return result_img, result_target
202
+
203
+
204
+ class CenterCrop(object):
205
+ def __init__(self, size):
206
+ self.size = size
207
+
208
+ def __call__(self, img, target):
209
+ image_width, image_height = img.size
210
+ crop_height, crop_width = self.size
211
+ crop_top = int(round((image_height - crop_height) / 2.0))
212
+ crop_left = int(round((image_width - crop_width) / 2.0))
213
+ return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
214
+
215
+
216
+ class RandomHorizontalFlip(object):
217
+ def __init__(self, p=0.5):
218
+ self.p = p
219
+
220
+ def __call__(self, img, target):
221
+ if random.random() < self.p:
222
+ return hflip(img, target)
223
+ return img, target
224
+
225
+
226
+ class RandomResize(object):
227
+ def __init__(self, sizes, max_size=None):
228
+ assert isinstance(sizes, (list, tuple))
229
+ self.sizes = sizes
230
+ self.max_size = max_size
231
+
232
+ def __call__(self, img, target=None):
233
+ size = random.choice(self.sizes)
234
+ return resize(img, target, size, self.max_size)
235
+
236
+
237
+ class RandomPad(object):
238
+ def __init__(self, max_pad):
239
+ self.max_pad = max_pad
240
+
241
+ def __call__(self, img, target):
242
+ pad_x = random.randint(0, self.max_pad)
243
+ pad_y = random.randint(0, self.max_pad)
244
+ return pad(img, target, (pad_x, pad_y))
245
+
246
+
247
+ class RandomSelect(object):
248
+ """
249
+ Randomly selects between transforms1 and transforms2,
250
+ with probability p for transforms1 and (1 - p) for transforms2
251
+ """
252
+
253
+ def __init__(self, transforms1, transforms2, p=0.5):
254
+ self.transforms1 = transforms1
255
+ self.transforms2 = transforms2
256
+ self.p = p
257
+
258
+ def __call__(self, img, target):
259
+ if random.random() < self.p:
260
+ return self.transforms1(img, target)
261
+ return self.transforms2(img, target)
262
+
263
+
264
+ class ToTensor(object):
265
+ def __call__(self, img, target):
266
+ return F.to_tensor(img), target
267
+
268
+
269
+ class RandomErasing(object):
270
+ def __init__(self, *args, **kwargs):
271
+ self.eraser = T.RandomErasing(*args, **kwargs)
272
+
273
+ def __call__(self, img, target):
274
+ return self.eraser(img), target
275
+
276
+
277
+ class Normalize(object):
278
+ def __init__(self, mean, std):
279
+ self.mean = mean
280
+ self.std = std
281
+
282
+ def __call__(self, image, target=None):
283
+ image = F.normalize(image, mean=self.mean, std=self.std)
284
+ if target is None:
285
+ return image, None
286
+ target = target.copy()
287
+ h, w = image.shape[-2:]
288
+ if "boxes" in target:
289
+ boxes = target["boxes"]
290
+ boxes = box_xyxy_to_cxcywh(boxes)
291
+ boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
292
+ target["boxes"] = boxes
293
+ return image, target
294
+
295
+
296
+ class Compose(object):
297
+ def __init__(self, transforms):
298
+ self.transforms = transforms
299
+
300
+ def __call__(self, image, target):
301
+ for t in self.transforms:
302
+ image, target = t(image, target)
303
+ return image, target
304
+
305
+ def __repr__(self):
306
+ format_string = self.__class__.__name__ + "("
307
+ for t in self.transforms:
308
+ format_string += "\n"
309
+ format_string += " {0}".format(t)
310
+ format_string += "\n)"
311
+ return format_string
GroundingDINO/groundingdino/models/GroundingDINO/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Conditional DETR
8
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Copied from DETR (https://github.com/facebookresearch/detr)
12
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13
+ # ------------------------------------------------------------------------
14
+
15
+ from .groundingdino import build_groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/backbone/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .backbone import build_backbone
GroundingDINO/groundingdino/models/GroundingDINO/backbone/backbone.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Conditional DETR
8
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Copied from DETR (https://github.com/facebookresearch/detr)
12
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13
+ # ------------------------------------------------------------------------
14
+
15
+ """
16
+ Backbone modules.
17
+ """
18
+
19
+ from typing import Dict, List
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torchvision
24
+ from torch import nn
25
+ from torchvision.models._utils import IntermediateLayerGetter
26
+
27
+ from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process
28
+
29
+ from .position_encoding import build_position_encoding
30
+ from .swin_transformer import build_swin_transformer
31
+
32
+
33
+ class FrozenBatchNorm2d(torch.nn.Module):
34
+ """
35
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
36
+
37
+ Copy-paste from torchvision.misc.ops with added eps before rqsrt,
38
+ without which any other models than torchvision.models.resnet[18,34,50,101]
39
+ produce nans.
40
+ """
41
+
42
+ def __init__(self, n):
43
+ super(FrozenBatchNorm2d, self).__init__()
44
+ self.register_buffer("weight", torch.ones(n))
45
+ self.register_buffer("bias", torch.zeros(n))
46
+ self.register_buffer("running_mean", torch.zeros(n))
47
+ self.register_buffer("running_var", torch.ones(n))
48
+
49
+ def _load_from_state_dict(
50
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
51
+ ):
52
+ num_batches_tracked_key = prefix + "num_batches_tracked"
53
+ if num_batches_tracked_key in state_dict:
54
+ del state_dict[num_batches_tracked_key]
55
+
56
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
57
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
58
+ )
59
+
60
+ def forward(self, x):
61
+ # move reshapes to the beginning
62
+ # to make it fuser-friendly
63
+ w = self.weight.reshape(1, -1, 1, 1)
64
+ b = self.bias.reshape(1, -1, 1, 1)
65
+ rv = self.running_var.reshape(1, -1, 1, 1)
66
+ rm = self.running_mean.reshape(1, -1, 1, 1)
67
+ eps = 1e-5
68
+ scale = w * (rv + eps).rsqrt()
69
+ bias = b - rm * scale
70
+ return x * scale + bias
71
+
72
+
73
+ class BackboneBase(nn.Module):
74
+ def __init__(
75
+ self,
76
+ backbone: nn.Module,
77
+ train_backbone: bool,
78
+ num_channels: int,
79
+ return_interm_indices: list,
80
+ ):
81
+ super().__init__()
82
+ for name, parameter in backbone.named_parameters():
83
+ if (
84
+ not train_backbone
85
+ or "layer2" not in name
86
+ and "layer3" not in name
87
+ and "layer4" not in name
88
+ ):
89
+ parameter.requires_grad_(False)
90
+
91
+ return_layers = {}
92
+ for idx, layer_index in enumerate(return_interm_indices):
93
+ return_layers.update(
94
+ {"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
95
+ )
96
+
97
+ # if len:
98
+ # if use_stage1_feature:
99
+ # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
100
+ # else:
101
+ # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
102
+ # else:
103
+ # return_layers = {'layer4': "0"}
104
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
105
+ self.num_channels = num_channels
106
+
107
+ def forward(self, tensor_list: NestedTensor):
108
+ xs = self.body(tensor_list.tensors)
109
+ out: Dict[str, NestedTensor] = {}
110
+ for name, x in xs.items():
111
+ m = tensor_list.mask
112
+ assert m is not None
113
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
114
+ out[name] = NestedTensor(x, mask)
115
+ # import ipdb; ipdb.set_trace()
116
+ return out
117
+
118
+
119
+ class Backbone(BackboneBase):
120
+ """ResNet backbone with frozen BatchNorm."""
121
+
122
+ def __init__(
123
+ self,
124
+ name: str,
125
+ train_backbone: bool,
126
+ dilation: bool,
127
+ return_interm_indices: list,
128
+ batch_norm=FrozenBatchNorm2d,
129
+ ):
130
+ if name in ["resnet18", "resnet34", "resnet50", "resnet101"]:
131
+ backbone = getattr(torchvision.models, name)(
132
+ replace_stride_with_dilation=[False, False, dilation],
133
+ pretrained=is_main_process(),
134
+ norm_layer=batch_norm,
135
+ )
136
+ else:
137
+ raise NotImplementedError("Why you can get here with name {}".format(name))
138
+ # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
139
+ assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
140
+ assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
141
+ num_channels_all = [256, 512, 1024, 2048]
142
+ num_channels = num_channels_all[4 - len(return_interm_indices) :]
143
+ super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
144
+
145
+
146
+ class Joiner(nn.Sequential):
147
+ def __init__(self, backbone, position_embedding):
148
+ super().__init__(backbone, position_embedding)
149
+
150
+ def forward(self, tensor_list: NestedTensor):
151
+ xs = self[0](tensor_list)
152
+ out: List[NestedTensor] = []
153
+ pos = []
154
+ for name, x in xs.items():
155
+ out.append(x)
156
+ # position encoding
157
+ pos.append(self[1](x).to(x.tensors.dtype))
158
+
159
+ return out, pos
160
+
161
+
162
+ def build_backbone(args):
163
+ """
164
+ Useful args:
165
+ - backbone: backbone name
166
+ - lr_backbone:
167
+ - dilation
168
+ - return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
169
+ - backbone_freeze_keywords:
170
+ - use_checkpoint: for swin only for now
171
+
172
+ """
173
+ position_embedding = build_position_encoding(args)
174
+ train_backbone = True
175
+ if not train_backbone:
176
+ raise ValueError("Please set lr_backbone > 0")
177
+ return_interm_indices = args.return_interm_indices
178
+ assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
179
+ args.backbone_freeze_keywords
180
+ use_checkpoint = getattr(args, "use_checkpoint", False)
181
+
182
+ if args.backbone in ["resnet50", "resnet101"]:
183
+ backbone = Backbone(
184
+ args.backbone,
185
+ train_backbone,
186
+ args.dilation,
187
+ return_interm_indices,
188
+ batch_norm=FrozenBatchNorm2d,
189
+ )
190
+ bb_num_channels = backbone.num_channels
191
+ elif args.backbone in [
192
+ "swin_T_224_1k",
193
+ "swin_B_224_22k",
194
+ "swin_B_384_22k",
195
+ "swin_L_224_22k",
196
+ "swin_L_384_22k",
197
+ ]:
198
+ pretrain_img_size = int(args.backbone.split("_")[-2])
199
+ backbone = build_swin_transformer(
200
+ args.backbone,
201
+ pretrain_img_size=pretrain_img_size,
202
+ out_indices=tuple(return_interm_indices),
203
+ dilation=False,
204
+ use_checkpoint=use_checkpoint,
205
+ )
206
+
207
+ bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
208
+ else:
209
+ raise NotImplementedError("Unknown backbone {}".format(args.backbone))
210
+
211
+ assert len(bb_num_channels) == len(
212
+ return_interm_indices
213
+ ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
214
+
215
+ model = Joiner(backbone, position_embedding)
216
+ model.num_channels = bb_num_channels
217
+ assert isinstance(
218
+ bb_num_channels, List
219
+ ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
220
+ # import ipdb; ipdb.set_trace()
221
+ return model
GroundingDINO/groundingdino/models/GroundingDINO/backbone/position_encoding.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # DINO
8
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Conditional DETR
12
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
13
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
14
+ # ------------------------------------------------------------------------
15
+ # Copied from DETR (https://github.com/facebookresearch/detr)
16
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
17
+ # ------------------------------------------------------------------------
18
+
19
+ """
20
+ Various positional encodings for the transformer.
21
+ """
22
+ import math
23
+
24
+ import torch
25
+ from torch import nn
26
+
27
+ from groundingdino.util.misc import NestedTensor
28
+
29
+
30
+ class PositionEmbeddingSine(nn.Module):
31
+ """
32
+ This is a more standard version of the position embedding, very similar to the one
33
+ used by the Attention is all you need paper, generalized to work on images.
34
+ """
35
+
36
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
37
+ super().__init__()
38
+ self.num_pos_feats = num_pos_feats
39
+ self.temperature = temperature
40
+ self.normalize = normalize
41
+ if scale is not None and normalize is False:
42
+ raise ValueError("normalize should be True if scale is passed")
43
+ if scale is None:
44
+ scale = 2 * math.pi
45
+ self.scale = scale
46
+
47
+ def forward(self, tensor_list: NestedTensor):
48
+ x = tensor_list.tensors
49
+ mask = tensor_list.mask
50
+ assert mask is not None
51
+ not_mask = ~mask
52
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
53
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
54
+ if self.normalize:
55
+ eps = 1e-6
56
+ # if os.environ.get("SHILONG_AMP", None) == '1':
57
+ # eps = 1e-4
58
+ # else:
59
+ # eps = 1e-6
60
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
61
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
62
+
63
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
64
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
65
+
66
+ pos_x = x_embed[:, :, :, None] / dim_t
67
+ pos_y = y_embed[:, :, :, None] / dim_t
68
+ pos_x = torch.stack(
69
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
70
+ ).flatten(3)
71
+ pos_y = torch.stack(
72
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
73
+ ).flatten(3)
74
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
75
+ return pos
76
+
77
+
78
+ class PositionEmbeddingSineHW(nn.Module):
79
+ """
80
+ This is a more standard version of the position embedding, very similar to the one
81
+ used by the Attention is all you need paper, generalized to work on images.
82
+ """
83
+
84
+ def __init__(
85
+ self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None
86
+ ):
87
+ super().__init__()
88
+ self.num_pos_feats = num_pos_feats
89
+ self.temperatureH = temperatureH
90
+ self.temperatureW = temperatureW
91
+ self.normalize = normalize
92
+ if scale is not None and normalize is False:
93
+ raise ValueError("normalize should be True if scale is passed")
94
+ if scale is None:
95
+ scale = 2 * math.pi
96
+ self.scale = scale
97
+
98
+ def forward(self, tensor_list: NestedTensor):
99
+ x = tensor_list.tensors
100
+ mask = tensor_list.mask
101
+ assert mask is not None
102
+ not_mask = ~mask
103
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
104
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
105
+
106
+ # import ipdb; ipdb.set_trace()
107
+
108
+ if self.normalize:
109
+ eps = 1e-6
110
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
111
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
112
+
113
+ dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
114
+ dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
115
+ pos_x = x_embed[:, :, :, None] / dim_tx
116
+
117
+ dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
118
+ dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
119
+ pos_y = y_embed[:, :, :, None] / dim_ty
120
+
121
+ pos_x = torch.stack(
122
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
123
+ ).flatten(3)
124
+ pos_y = torch.stack(
125
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
126
+ ).flatten(3)
127
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
128
+
129
+ # import ipdb; ipdb.set_trace()
130
+
131
+ return pos
132
+
133
+
134
+ class PositionEmbeddingLearned(nn.Module):
135
+ """
136
+ Absolute pos embedding, learned.
137
+ """
138
+
139
+ def __init__(self, num_pos_feats=256):
140
+ super().__init__()
141
+ self.row_embed = nn.Embedding(50, num_pos_feats)
142
+ self.col_embed = nn.Embedding(50, num_pos_feats)
143
+ self.reset_parameters()
144
+
145
+ def reset_parameters(self):
146
+ nn.init.uniform_(self.row_embed.weight)
147
+ nn.init.uniform_(self.col_embed.weight)
148
+
149
+ def forward(self, tensor_list: NestedTensor):
150
+ x = tensor_list.tensors
151
+ h, w = x.shape[-2:]
152
+ i = torch.arange(w, device=x.device)
153
+ j = torch.arange(h, device=x.device)
154
+ x_emb = self.col_embed(i)
155
+ y_emb = self.row_embed(j)
156
+ pos = (
157
+ torch.cat(
158
+ [
159
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
160
+ y_emb.unsqueeze(1).repeat(1, w, 1),
161
+ ],
162
+ dim=-1,
163
+ )
164
+ .permute(2, 0, 1)
165
+ .unsqueeze(0)
166
+ .repeat(x.shape[0], 1, 1, 1)
167
+ )
168
+ return pos
169
+
170
+
171
+ def build_position_encoding(args):
172
+ N_steps = args.hidden_dim // 2
173
+ if args.position_embedding in ("v2", "sine"):
174
+ # TODO find a better way of exposing other arguments
175
+ position_embedding = PositionEmbeddingSineHW(
176
+ N_steps,
177
+ temperatureH=args.pe_temperatureH,
178
+ temperatureW=args.pe_temperatureW,
179
+ normalize=True,
180
+ )
181
+ elif args.position_embedding in ("v3", "learned"):
182
+ position_embedding = PositionEmbeddingLearned(N_steps)
183
+ else:
184
+ raise ValueError(f"not supported {args.position_embedding}")
185
+
186
+ return position_embedding
{models → GroundingDINO/groundingdino/models/GroundingDINO/backbone}/swin_transformer.py RENAMED
@@ -1,21 +1,32 @@
 
 
 
 
 
 
 
 
 
1
  # --------------------------------------------------------
2
- # Swin Transformer
3
- # Copyright (c) 2021 Microsoft
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Ze Liu
6
  # --------------------------------------------------------
7
 
8
  import numpy as np
9
- from scipy import interpolate
10
-
11
  import torch
12
  import torch.nn as nn
 
13
  import torch.utils.checkpoint as checkpoint
14
  from timm.models.layers import DropPath, to_2tuple, trunc_normal_
15
 
 
 
16
 
17
  class Mlp(nn.Module):
18
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
 
 
 
 
19
  super().__init__()
20
  out_features = out_features or in_features
21
  hidden_features = hidden_features or in_features
@@ -38,7 +49,6 @@ def window_partition(x, window_size):
38
  Args:
39
  x: (B, H, W, C)
40
  window_size (int): window size
41
-
42
  Returns:
43
  windows: (num_windows*B, window_size, window_size, C)
44
  """
@@ -55,7 +65,6 @@ def window_reverse(windows, window_size, H, W):
55
  window_size (int): Window size
56
  H (int): Height of image
57
  W (int): Width of image
58
-
59
  Returns:
60
  x: (B, H, W, C)
61
  """
@@ -66,9 +75,8 @@ def window_reverse(windows, window_size, H, W):
66
 
67
 
68
  class WindowAttention(nn.Module):
69
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
70
  It supports both of shifted and non-shifted window.
71
-
72
  Args:
73
  dim (int): Number of input channels.
74
  window_size (tuple[int]): The height and width of the window.
@@ -79,18 +87,28 @@ class WindowAttention(nn.Module):
79
  proj_drop (float, optional): Dropout ratio of output. Default: 0.0
80
  """
81
 
82
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
 
 
 
 
 
 
 
 
 
83
 
84
  super().__init__()
85
  self.dim = dim
86
  self.window_size = window_size # Wh, Ww
87
  self.num_heads = num_heads
88
  head_dim = dim // num_heads
89
- self.scale = qk_scale or head_dim ** -0.5
90
 
91
  # define a parameter table of relative position bias
92
  self.relative_position_bias_table = nn.Parameter(
93
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
 
94
 
95
  # get pair-wise relative position index for each token inside the window
96
  coords_h = torch.arange(self.window_size[0])
@@ -110,25 +128,34 @@ class WindowAttention(nn.Module):
110
  self.proj = nn.Linear(dim, dim)
111
  self.proj_drop = nn.Dropout(proj_drop)
112
 
113
- trunc_normal_(self.relative_position_bias_table, std=.02)
114
  self.softmax = nn.Softmax(dim=-1)
115
 
116
  def forward(self, x, mask=None):
117
- """
118
  Args:
119
  x: input features with shape of (num_windows*B, N, C)
120
  mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
121
  """
122
  B_, N, C = x.shape
123
- qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
 
 
 
 
124
  q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
125
 
126
  q = q * self.scale
127
- attn = (q @ k.transpose(-2, -1))
128
-
129
- relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
130
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
131
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
 
 
 
 
 
132
  attn = attn + relative_position_bias.unsqueeze(0)
133
 
134
  if mask is not None:
@@ -146,29 +173,11 @@ class WindowAttention(nn.Module):
146
  x = self.proj_drop(x)
147
  return x
148
 
149
- def extra_repr(self) -> str:
150
- return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
151
-
152
- def flops(self, N):
153
- # calculate flops for 1 window with token length of N
154
- flops = 0
155
- # qkv = self.qkv(x)
156
- flops += N * self.dim * 3 * self.dim
157
- # attn = (q @ k.transpose(-2, -1))
158
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
159
- # x = (attn @ v)
160
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
161
- # x = self.proj(x)
162
- flops += N * self.dim * self.dim
163
- return flops
164
-
165
 
166
  class SwinTransformerBlock(nn.Module):
167
- r""" Swin Transformer Block.
168
-
169
  Args:
170
  dim (int): Number of input channels.
171
- input_resolution (tuple[int]): Input resulotion.
172
  num_heads (int): Number of attention heads.
173
  window_size (int): Window size.
174
  shift_size (int): Shift size for SW-MSA.
@@ -182,88 +191,104 @@ class SwinTransformerBlock(nn.Module):
182
  norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
183
  """
184
 
185
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
186
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
187
- act_layer=nn.GELU, norm_layer=nn.LayerNorm):
 
 
 
 
 
 
 
 
 
 
 
 
188
  super().__init__()
189
  self.dim = dim
190
- self.input_resolution = input_resolution
191
  self.num_heads = num_heads
192
  self.window_size = window_size
193
  self.shift_size = shift_size
194
  self.mlp_ratio = mlp_ratio
195
- if min(self.input_resolution) <= self.window_size:
196
- # if window size is larger than input resolution, we don't partition windows
197
- self.shift_size = 0
198
- self.window_size = min(self.input_resolution)
199
  assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
200
 
201
  self.norm1 = norm_layer(dim)
202
  self.attn = WindowAttention(
203
- dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
204
- qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
205
-
206
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
 
 
 
 
 
207
  self.norm2 = norm_layer(dim)
208
  mlp_hidden_dim = int(dim * mlp_ratio)
209
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
210
-
211
- if self.shift_size > 0:
212
- # calculate attention mask for SW-MSA
213
- H, W = self.input_resolution
214
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
215
- h_slices = (slice(0, -self.window_size),
216
- slice(-self.window_size, -self.shift_size),
217
- slice(-self.shift_size, None))
218
- w_slices = (slice(0, -self.window_size),
219
- slice(-self.window_size, -self.shift_size),
220
- slice(-self.shift_size, None))
221
- cnt = 0
222
- for h in h_slices:
223
- for w in w_slices:
224
- img_mask[:, h, w, :] = cnt
225
- cnt += 1
226
-
227
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
228
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
229
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
230
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
231
- else:
232
- attn_mask = None
233
 
234
- self.register_buffer("attn_mask", attn_mask)
 
235
 
236
- def forward(self, x):
237
- H, W = self.input_resolution
 
 
 
 
 
238
  B, L, C = x.shape
 
239
  assert L == H * W, "input feature has wrong size"
240
 
241
  shortcut = x
242
  x = self.norm1(x)
243
  x = x.view(B, H, W, C)
244
 
 
 
 
 
 
 
 
245
  # cyclic shift
246
  if self.shift_size > 0:
247
  shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
 
248
  else:
249
  shifted_x = x
 
250
 
251
  # partition windows
252
- x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
253
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
 
 
 
 
254
 
255
  # W-MSA/SW-MSA
256
- attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
257
 
258
  # merge windows
259
  attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
260
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
261
 
262
  # reverse cyclic shift
263
  if self.shift_size > 0:
264
  x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
265
  else:
266
  x = shifted_x
 
 
 
 
267
  x = x.view(B, H * W, C)
268
 
269
  # FFN
@@ -272,52 +297,36 @@ class SwinTransformerBlock(nn.Module):
272
 
273
  return x
274
 
275
- def extra_repr(self) -> str:
276
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
277
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
278
-
279
- def flops(self):
280
- flops = 0
281
- H, W = self.input_resolution
282
- # norm1
283
- flops += self.dim * H * W
284
- # W-MSA/SW-MSA
285
- nW = H * W / self.window_size / self.window_size
286
- flops += nW * self.attn.flops(self.window_size * self.window_size)
287
- # mlp
288
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
289
- # norm2
290
- flops += self.dim * H * W
291
- return flops
292
-
293
 
294
  class PatchMerging(nn.Module):
295
- r""" Patch Merging Layer.
296
-
297
  Args:
298
- input_resolution (tuple[int]): Resolution of input feature.
299
  dim (int): Number of input channels.
300
  norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
301
  """
302
 
303
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
304
  super().__init__()
305
- self.input_resolution = input_resolution
306
  self.dim = dim
307
  self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
308
  self.norm = norm_layer(4 * dim)
309
 
310
- def forward(self, x):
311
- """
312
- x: B, H*W, C
 
 
313
  """
314
- H, W = self.input_resolution
315
  B, L, C = x.shape
316
  assert L == H * W, "input feature has wrong size"
317
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
318
 
319
  x = x.view(B, H, W, C)
320
 
 
 
 
 
 
321
  x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
322
  x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
323
  x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
@@ -330,26 +339,15 @@ class PatchMerging(nn.Module):
330
 
331
  return x
332
 
333
- def extra_repr(self) -> str:
334
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
335
-
336
- def flops(self):
337
- H, W = self.input_resolution
338
- flops = H * W * self.dim
339
- flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
340
- return flops
341
-
342
 
343
  class BasicLayer(nn.Module):
344
- """ A basic Swin Transformer layer for one stage.
345
-
346
  Args:
347
- dim (int): Number of input channels.
348
- input_resolution (tuple[int]): Input resolution.
349
- depth (int): Number of blocks.
350
- num_heads (int): Number of attention heads.
351
- window_size (int): Local window size.
352
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
353
  qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
354
  qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
355
  drop (float, optional): Dropout rate. Default: 0.0
@@ -360,76 +358,117 @@ class BasicLayer(nn.Module):
360
  use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
361
  """
362
 
363
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
364
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
365
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
366
-
 
 
 
 
 
 
 
 
 
 
 
 
367
  super().__init__()
368
- self.dim = dim
369
- self.input_resolution = input_resolution
370
  self.depth = depth
371
  self.use_checkpoint = use_checkpoint
372
 
373
  # build blocks
374
- self.blocks = nn.ModuleList([
375
- SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
376
- num_heads=num_heads, window_size=window_size,
377
- shift_size=0 if (i % 2 == 0) else window_size // 2,
378
- mlp_ratio=mlp_ratio,
379
- qkv_bias=qkv_bias, qk_scale=qk_scale,
380
- drop=drop, attn_drop=attn_drop,
381
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
382
- norm_layer=norm_layer)
383
- for i in range(depth)])
 
 
 
 
 
 
 
 
384
 
385
  # patch merging layer
386
  if downsample is not None:
387
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
388
  else:
389
  self.downsample = None
390
 
391
- def forward(self, x):
392
- for blk in self.blocks:
393
- if self.use_checkpoint:
394
- x = checkpoint.checkpoint(blk, x)
395
- else:
396
- x = blk(x)
397
- if self.downsample is not None:
398
- x = self.downsample(x)
399
- return x
400
 
401
- def extra_repr(self) -> str:
402
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
 
404
- def flops(self):
405
- flops = 0
406
  for blk in self.blocks:
407
- flops += blk.flops()
 
 
 
 
408
  if self.downsample is not None:
409
- flops += self.downsample.flops()
410
- return flops
 
 
 
411
 
412
 
413
  class PatchEmbed(nn.Module):
414
- r""" Image to Patch Embedding
415
-
416
  Args:
417
- img_size (int): Image size. Default: 224.
418
  patch_size (int): Patch token size. Default: 4.
419
  in_chans (int): Number of input image channels. Default: 3.
420
  embed_dim (int): Number of linear projection output channels. Default: 96.
421
  norm_layer (nn.Module, optional): Normalization layer. Default: None
422
  """
423
 
424
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
425
  super().__init__()
426
- img_size = to_2tuple(img_size)
427
  patch_size = to_2tuple(patch_size)
428
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
429
- self.img_size = img_size
430
  self.patch_size = patch_size
431
- self.patches_resolution = patches_resolution
432
- self.num_patches = patches_resolution[0] * patches_resolution[1]
433
 
434
  self.in_chans = in_chans
435
  self.embed_dim = embed_dim
@@ -441,214 +480,323 @@ class PatchEmbed(nn.Module):
441
  self.norm = None
442
 
443
  def forward(self, x):
444
- B, C, H, W = x.shape
445
- # FIXME look at relaxing size constraints
446
- assert H == self.img_size[0] and W == self.img_size[1], \
447
- f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
448
- x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
 
 
 
 
449
  if self.norm is not None:
 
 
450
  x = self.norm(x)
451
- return x
452
 
453
- def flops(self):
454
- Ho, Wo = self.patches_resolution
455
- flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
456
- if self.norm is not None:
457
- flops += Ho * Wo * self.embed_dim
458
- return flops
459
 
460
 
461
  class SwinTransformer(nn.Module):
462
- r""" Swin Transformer
463
  A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
464
  https://arxiv.org/pdf/2103.14030
465
-
466
  Args:
467
- img_size (int | tuple(int)): Input image size. Default 224
468
- patch_size (int | tuple(int)): Patch size. Default: 4
469
- in_chans (int): Number of input image channels. Default: 3
470
- num_classes (int): Number of classes for classification head. Default: 1000
471
- embed_dim (int): Patch embedding dimension. Default: 96
472
- depths (tuple(int)): Depth of each Swin Transformer layer.
473
- num_heads (tuple(int)): Number of attention heads in different layers.
474
- window_size (int): Window size. Default: 7
475
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
476
  qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
477
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
478
- drop_rate (float): Dropout rate. Default: 0
479
- attn_drop_rate (float): Attention dropout rate. Default: 0
480
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
481
  norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
482
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
483
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
484
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
 
 
 
 
485
  """
486
 
487
- def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
488
- embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
489
- window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
490
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
491
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
492
- use_checkpoint=False, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
  super().__init__()
494
 
495
- self.num_classes = num_classes
496
  self.num_layers = len(depths)
497
  self.embed_dim = embed_dim
498
  self.ape = ape
499
  self.patch_norm = patch_norm
500
- self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
501
- self.mlp_ratio = mlp_ratio
 
 
 
 
502
 
503
  # split image into non-overlapping patches
504
  self.patch_embed = PatchEmbed(
505
- img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
506
- norm_layer=norm_layer if self.patch_norm else None)
507
- num_patches = self.patch_embed.num_patches
508
- patches_resolution = self.patch_embed.patches_resolution
509
- self.patches_resolution = patches_resolution
510
 
511
  # absolute position embedding
512
  if self.ape:
513
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
514
- trunc_normal_(self.absolute_pos_embed, std=.02)
 
 
 
 
 
 
 
 
 
515
 
516
  self.pos_drop = nn.Dropout(p=drop_rate)
517
 
518
  # stochastic depth
519
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
 
 
520
 
521
  # build layers
522
  self.layers = nn.ModuleList()
 
 
 
 
 
 
 
523
  for i_layer in range(self.num_layers):
524
- layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
525
- input_resolution=(patches_resolution[0] // (2 ** i_layer),
526
- patches_resolution[1] // (2 ** i_layer)),
527
- depth=depths[i_layer],
528
- num_heads=num_heads[i_layer],
529
- window_size=window_size,
530
- mlp_ratio=self.mlp_ratio,
531
- qkv_bias=qkv_bias, qk_scale=qk_scale,
532
- drop=drop_rate, attn_drop=attn_drop_rate,
533
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
534
- norm_layer=norm_layer,
535
- downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
536
- use_checkpoint=use_checkpoint)
 
 
 
 
537
  self.layers.append(layer)
538
 
539
- self.norm = norm_layer(self.num_features)
540
- self.avgpool = nn.AdaptiveAvgPool1d(1)
541
- # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
542
-
543
- self.apply(self._init_weights)
544
-
545
- def _init_weights(self, m):
546
- if isinstance(m, nn.Linear):
547
- trunc_normal_(m.weight, std=.02)
548
- if isinstance(m, nn.Linear) and m.bias is not None:
549
- nn.init.constant_(m.bias, 0)
550
- elif isinstance(m, nn.LayerNorm):
551
- nn.init.constant_(m.bias, 0)
552
- nn.init.constant_(m.weight, 1.0)
553
-
554
- @torch.jit.ignore
555
- def no_weight_decay(self):
556
- return {'absolute_pos_embed'}
557
-
558
- @torch.jit.ignore
559
- def no_weight_decay_keywords(self):
560
- return {'relative_position_bias_table'}
561
-
562
- def forward(self, x, idx_to_group_img=None, image_atts=None, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
563
  x = self.patch_embed(x)
 
 
564
  if self.ape:
565
- x = x + self.absolute_pos_embed
 
 
 
 
 
 
566
  x = self.pos_drop(x)
567
 
568
- for layer in self.layers:
569
- x = layer(x)
570
-
571
- x = self.norm(x) # B L C
572
-
573
- x_cls = self.avgpool(x.transpose(1, 2)) # B C 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
 
575
- if idx_to_group_img is None:
576
- return torch.cat([x_cls.transpose(1, 2), x], dim=1)
 
 
 
 
 
577
  else:
578
- x_bs = torch.gather(x, dim=0, index=idx_to_group_img.view(-1, 1, 1).expand(-1, x.shape[1], x.shape[2]))
579
- weights = image_atts[:, 1:].unsqueeze(2) # B L 1
580
- x_bs_cls = torch.sum((weights * x_bs).transpose(1, 2), dim=-1, keepdim=True) # B C 1
581
- x_bs_cls = x_bs_cls / torch.sum(weights.transpose(1, 2), dim=-1, keepdim=True) # avgpool
582
-
583
- return torch.cat([x_bs_cls.transpose(1, 2), x_bs], dim=1), \
584
- torch.cat([x_cls.transpose(1, 2), x], dim=1)
585
-
586
- def flops(self):
587
- flops = 0
588
- flops += self.patch_embed.flops()
589
- for i, layer in enumerate(self.layers):
590
- flops += layer.flops()
591
- flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
592
- flops += self.num_features * self.num_classes
593
- return flops
594
-
595
-
596
- def interpolate_relative_pos_embed(rel_pos_bias, dst_num_pos, param_name=''):
597
- # from: https://github.com/microsoft/unilm/blob/8a0a1c1f4e7326938ea7580a00d56d7f17d65612/beit/run_class_finetuning.py#L348
598
-
599
- # rel_pos_bias: relative_position_bias_table
600
- src_num_pos, num_attn_heads = rel_pos_bias.size()
601
-
602
- num_extra_tokens = 0
603
- src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
604
- dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
605
- if src_size != dst_size:
606
- print("Position interpolate %s from %dx%d to %dx%d" % (param_name, src_size, src_size, dst_size, dst_size))
607
-
608
- # extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
609
- # rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
610
-
611
- def geometric_progression(a, r, n):
612
- return a * (1.0 - r ** n) / (1.0 - r)
613
-
614
- left, right = 1.01, 1.5
615
- while right - left > 1e-6:
616
- q = (left + right) / 2.0
617
- gp = geometric_progression(1, q, src_size // 2)
618
- if gp > dst_size // 2:
619
- right = q
620
- else:
621
- left = q
622
-
623
- # if q > 1.090307:
624
- # q = 1.090307
625
-
626
- dis = []
627
- cur = 1
628
- for i in range(src_size // 2):
629
- dis.append(cur)
630
- cur += q ** (i + 1)
631
-
632
- r_ids = [-_ for _ in reversed(dis)]
633
-
634
- x = r_ids + [0] + dis
635
- y = r_ids + [0] + dis
636
-
637
- t = dst_size // 2.0
638
- dx = np.arange(-t, t + 0.1, 1.0)
639
- dy = np.arange(-t, t + 0.1, 1.0)
640
-
641
- # print("Original positions = %s" % str(x))
642
- # print("Target positions = %s" % str(dx))
643
-
644
- all_rel_pos_bias = []
645
-
646
- for i in range(num_attn_heads):
647
- z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
648
- f = interpolate.interp2d(x, y, z, kind='cubic')
649
- all_rel_pos_bias.append(
650
- torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
651
-
652
- rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
653
 
654
- return rel_pos_bias
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # DINO
8
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
  # --------------------------------------------------------
11
+ # modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
 
 
 
12
  # --------------------------------------------------------
13
 
14
  import numpy as np
 
 
15
  import torch
16
  import torch.nn as nn
17
+ import torch.nn.functional as F
18
  import torch.utils.checkpoint as checkpoint
19
  from timm.models.layers import DropPath, to_2tuple, trunc_normal_
20
 
21
+ from groundingdino.util.misc import NestedTensor
22
+
23
 
24
  class Mlp(nn.Module):
25
+ """Multilayer perceptron."""
26
+
27
+ def __init__(
28
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
29
+ ):
30
  super().__init__()
31
  out_features = out_features or in_features
32
  hidden_features = hidden_features or in_features
 
49
  Args:
50
  x: (B, H, W, C)
51
  window_size (int): window size
 
52
  Returns:
53
  windows: (num_windows*B, window_size, window_size, C)
54
  """
 
65
  window_size (int): Window size
66
  H (int): Height of image
67
  W (int): Width of image
 
68
  Returns:
69
  x: (B, H, W, C)
70
  """
 
75
 
76
 
77
  class WindowAttention(nn.Module):
78
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
79
  It supports both of shifted and non-shifted window.
 
80
  Args:
81
  dim (int): Number of input channels.
82
  window_size (tuple[int]): The height and width of the window.
 
87
  proj_drop (float, optional): Dropout ratio of output. Default: 0.0
88
  """
89
 
90
+ def __init__(
91
+ self,
92
+ dim,
93
+ window_size,
94
+ num_heads,
95
+ qkv_bias=True,
96
+ qk_scale=None,
97
+ attn_drop=0.0,
98
+ proj_drop=0.0,
99
+ ):
100
 
101
  super().__init__()
102
  self.dim = dim
103
  self.window_size = window_size # Wh, Ww
104
  self.num_heads = num_heads
105
  head_dim = dim // num_heads
106
+ self.scale = qk_scale or head_dim**-0.5
107
 
108
  # define a parameter table of relative position bias
109
  self.relative_position_bias_table = nn.Parameter(
110
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
111
+ ) # 2*Wh-1 * 2*Ww-1, nH
112
 
113
  # get pair-wise relative position index for each token inside the window
114
  coords_h = torch.arange(self.window_size[0])
 
128
  self.proj = nn.Linear(dim, dim)
129
  self.proj_drop = nn.Dropout(proj_drop)
130
 
131
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
132
  self.softmax = nn.Softmax(dim=-1)
133
 
134
  def forward(self, x, mask=None):
135
+ """Forward function.
136
  Args:
137
  x: input features with shape of (num_windows*B, N, C)
138
  mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
139
  """
140
  B_, N, C = x.shape
141
+ qkv = (
142
+ self.qkv(x)
143
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
144
+ .permute(2, 0, 3, 1, 4)
145
+ )
146
  q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
147
 
148
  q = q * self.scale
149
+ attn = q @ k.transpose(-2, -1)
150
+
151
+ relative_position_bias = self.relative_position_bias_table[
152
+ self.relative_position_index.view(-1)
153
+ ].view(
154
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
155
+ ) # Wh*Ww,Wh*Ww,nH
156
+ relative_position_bias = relative_position_bias.permute(
157
+ 2, 0, 1
158
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
159
  attn = attn + relative_position_bias.unsqueeze(0)
160
 
161
  if mask is not None:
 
173
  x = self.proj_drop(x)
174
  return x
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
  class SwinTransformerBlock(nn.Module):
178
+ """Swin Transformer Block.
 
179
  Args:
180
  dim (int): Number of input channels.
 
181
  num_heads (int): Number of attention heads.
182
  window_size (int): Window size.
183
  shift_size (int): Shift size for SW-MSA.
 
191
  norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
192
  """
193
 
194
+ def __init__(
195
+ self,
196
+ dim,
197
+ num_heads,
198
+ window_size=7,
199
+ shift_size=0,
200
+ mlp_ratio=4.0,
201
+ qkv_bias=True,
202
+ qk_scale=None,
203
+ drop=0.0,
204
+ attn_drop=0.0,
205
+ drop_path=0.0,
206
+ act_layer=nn.GELU,
207
+ norm_layer=nn.LayerNorm,
208
+ ):
209
  super().__init__()
210
  self.dim = dim
 
211
  self.num_heads = num_heads
212
  self.window_size = window_size
213
  self.shift_size = shift_size
214
  self.mlp_ratio = mlp_ratio
 
 
 
 
215
  assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
216
 
217
  self.norm1 = norm_layer(dim)
218
  self.attn = WindowAttention(
219
+ dim,
220
+ window_size=to_2tuple(self.window_size),
221
+ num_heads=num_heads,
222
+ qkv_bias=qkv_bias,
223
+ qk_scale=qk_scale,
224
+ attn_drop=attn_drop,
225
+ proj_drop=drop,
226
+ )
227
+
228
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
229
  self.norm2 = norm_layer(dim)
230
  mlp_hidden_dim = int(dim * mlp_ratio)
231
+ self.mlp = Mlp(
232
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
233
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
+ self.H = None
236
+ self.W = None
237
 
238
+ def forward(self, x, mask_matrix):
239
+ """Forward function.
240
+ Args:
241
+ x: Input feature, tensor size (B, H*W, C).
242
+ H, W: Spatial resolution of the input feature.
243
+ mask_matrix: Attention mask for cyclic shift.
244
+ """
245
  B, L, C = x.shape
246
+ H, W = self.H, self.W
247
  assert L == H * W, "input feature has wrong size"
248
 
249
  shortcut = x
250
  x = self.norm1(x)
251
  x = x.view(B, H, W, C)
252
 
253
+ # pad feature maps to multiples of window size
254
+ pad_l = pad_t = 0
255
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
256
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
257
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
258
+ _, Hp, Wp, _ = x.shape
259
+
260
  # cyclic shift
261
  if self.shift_size > 0:
262
  shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
263
+ attn_mask = mask_matrix
264
  else:
265
  shifted_x = x
266
+ attn_mask = None
267
 
268
  # partition windows
269
+ x_windows = window_partition(
270
+ shifted_x, self.window_size
271
+ ) # nW*B, window_size, window_size, C
272
+ x_windows = x_windows.view(
273
+ -1, self.window_size * self.window_size, C
274
+ ) # nW*B, window_size*window_size, C
275
 
276
  # W-MSA/SW-MSA
277
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
278
 
279
  # merge windows
280
  attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
281
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
282
 
283
  # reverse cyclic shift
284
  if self.shift_size > 0:
285
  x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
286
  else:
287
  x = shifted_x
288
+
289
+ if pad_r > 0 or pad_b > 0:
290
+ x = x[:, :H, :W, :].contiguous()
291
+
292
  x = x.view(B, H * W, C)
293
 
294
  # FFN
 
297
 
298
  return x
299
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
  class PatchMerging(nn.Module):
302
+ """Patch Merging Layer
 
303
  Args:
 
304
  dim (int): Number of input channels.
305
  norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
306
  """
307
 
308
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
309
  super().__init__()
 
310
  self.dim = dim
311
  self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
312
  self.norm = norm_layer(4 * dim)
313
 
314
+ def forward(self, x, H, W):
315
+ """Forward function.
316
+ Args:
317
+ x: Input feature, tensor size (B, H*W, C).
318
+ H, W: Spatial resolution of the input feature.
319
  """
 
320
  B, L, C = x.shape
321
  assert L == H * W, "input feature has wrong size"
 
322
 
323
  x = x.view(B, H, W, C)
324
 
325
+ # padding
326
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
327
+ if pad_input:
328
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
329
+
330
  x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
331
  x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
332
  x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
 
339
 
340
  return x
341
 
 
 
 
 
 
 
 
 
 
342
 
343
  class BasicLayer(nn.Module):
344
+ """A basic Swin Transformer layer for one stage.
 
345
  Args:
346
+ dim (int): Number of feature channels
347
+ depth (int): Depths of this stage.
348
+ num_heads (int): Number of attention head.
349
+ window_size (int): Local window size. Default: 7.
350
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
 
351
  qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
352
  qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
353
  drop (float, optional): Dropout rate. Default: 0.0
 
358
  use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
359
  """
360
 
361
+ def __init__(
362
+ self,
363
+ dim,
364
+ depth,
365
+ num_heads,
366
+ window_size=7,
367
+ mlp_ratio=4.0,
368
+ qkv_bias=True,
369
+ qk_scale=None,
370
+ drop=0.0,
371
+ attn_drop=0.0,
372
+ drop_path=0.0,
373
+ norm_layer=nn.LayerNorm,
374
+ downsample=None,
375
+ use_checkpoint=False,
376
+ ):
377
  super().__init__()
378
+ self.window_size = window_size
379
+ self.shift_size = window_size // 2
380
  self.depth = depth
381
  self.use_checkpoint = use_checkpoint
382
 
383
  # build blocks
384
+ self.blocks = nn.ModuleList(
385
+ [
386
+ SwinTransformerBlock(
387
+ dim=dim,
388
+ num_heads=num_heads,
389
+ window_size=window_size,
390
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
391
+ mlp_ratio=mlp_ratio,
392
+ qkv_bias=qkv_bias,
393
+ qk_scale=qk_scale,
394
+ drop=drop,
395
+ attn_drop=attn_drop,
396
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
397
+ norm_layer=norm_layer,
398
+ )
399
+ for i in range(depth)
400
+ ]
401
+ )
402
 
403
  # patch merging layer
404
  if downsample is not None:
405
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
406
  else:
407
  self.downsample = None
408
 
409
+ def forward(self, x, H, W):
410
+ """Forward function.
411
+ Args:
412
+ x: Input feature, tensor size (B, H*W, C).
413
+ H, W: Spatial resolution of the input feature.
414
+ """
 
 
 
415
 
416
+ # calculate attention mask for SW-MSA
417
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
418
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
419
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
420
+ h_slices = (
421
+ slice(0, -self.window_size),
422
+ slice(-self.window_size, -self.shift_size),
423
+ slice(-self.shift_size, None),
424
+ )
425
+ w_slices = (
426
+ slice(0, -self.window_size),
427
+ slice(-self.window_size, -self.shift_size),
428
+ slice(-self.shift_size, None),
429
+ )
430
+ cnt = 0
431
+ for h in h_slices:
432
+ for w in w_slices:
433
+ img_mask[:, h, w, :] = cnt
434
+ cnt += 1
435
+
436
+ mask_windows = window_partition(
437
+ img_mask, self.window_size
438
+ ) # nW, window_size, window_size, 1
439
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
440
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
441
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
442
+ attn_mask == 0, float(0.0)
443
+ )
444
 
 
 
445
  for blk in self.blocks:
446
+ blk.H, blk.W = H, W
447
+ if self.use_checkpoint:
448
+ x = checkpoint.checkpoint(blk, x, attn_mask)
449
+ else:
450
+ x = blk(x, attn_mask)
451
  if self.downsample is not None:
452
+ x_down = self.downsample(x, H, W)
453
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
454
+ return x, H, W, x_down, Wh, Ww
455
+ else:
456
+ return x, H, W, x, H, W
457
 
458
 
459
  class PatchEmbed(nn.Module):
460
+ """Image to Patch Embedding
 
461
  Args:
 
462
  patch_size (int): Patch token size. Default: 4.
463
  in_chans (int): Number of input image channels. Default: 3.
464
  embed_dim (int): Number of linear projection output channels. Default: 96.
465
  norm_layer (nn.Module, optional): Normalization layer. Default: None
466
  """
467
 
468
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
469
  super().__init__()
 
470
  patch_size = to_2tuple(patch_size)
 
 
471
  self.patch_size = patch_size
 
 
472
 
473
  self.in_chans = in_chans
474
  self.embed_dim = embed_dim
 
480
  self.norm = None
481
 
482
  def forward(self, x):
483
+ """Forward function."""
484
+ # padding
485
+ _, _, H, W = x.size()
486
+ if W % self.patch_size[1] != 0:
487
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
488
+ if H % self.patch_size[0] != 0:
489
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
490
+
491
+ x = self.proj(x) # B C Wh Ww
492
  if self.norm is not None:
493
+ Wh, Ww = x.size(2), x.size(3)
494
+ x = x.flatten(2).transpose(1, 2)
495
  x = self.norm(x)
496
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
497
 
498
+ return x
 
 
 
 
 
499
 
500
 
501
  class SwinTransformer(nn.Module):
502
+ """Swin Transformer backbone.
503
  A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
504
  https://arxiv.org/pdf/2103.14030
 
505
  Args:
506
+ pretrain_img_size (int): Input image size for training the pretrained model,
507
+ used in absolute postion embedding. Default 224.
508
+ patch_size (int | tuple(int)): Patch size. Default: 4.
509
+ in_chans (int): Number of input image channels. Default: 3.
510
+ embed_dim (int): Number of linear projection output channels. Default: 96.
511
+ depths (tuple[int]): Depths of each Swin Transformer stage.
512
+ num_heads (tuple[int]): Number of attention head of each stage.
513
+ window_size (int): Window size. Default: 7.
514
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
515
  qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
516
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
517
+ drop_rate (float): Dropout rate.
518
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
519
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
520
  norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
521
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
522
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
523
+ out_indices (Sequence[int]): Output from which stages.
524
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
525
+ -1 means not freezing any parameters.
526
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
527
+ dilation (bool): if True, the output size if 16x downsample, ow 32x downsample.
528
  """
529
 
530
+ def __init__(
531
+ self,
532
+ pretrain_img_size=224,
533
+ patch_size=4,
534
+ in_chans=3,
535
+ embed_dim=96,
536
+ depths=[2, 2, 6, 2],
537
+ num_heads=[3, 6, 12, 24],
538
+ window_size=7,
539
+ mlp_ratio=4.0,
540
+ qkv_bias=True,
541
+ qk_scale=None,
542
+ drop_rate=0.0,
543
+ attn_drop_rate=0.0,
544
+ drop_path_rate=0.2,
545
+ norm_layer=nn.LayerNorm,
546
+ ape=False,
547
+ patch_norm=True,
548
+ out_indices=(0, 1, 2, 3),
549
+ frozen_stages=-1,
550
+ dilation=False,
551
+ use_checkpoint=False,
552
+ ):
553
  super().__init__()
554
 
555
+ self.pretrain_img_size = pretrain_img_size
556
  self.num_layers = len(depths)
557
  self.embed_dim = embed_dim
558
  self.ape = ape
559
  self.patch_norm = patch_norm
560
+ self.out_indices = out_indices
561
+ self.frozen_stages = frozen_stages
562
+ self.dilation = dilation
563
+
564
+ # if use_checkpoint:
565
+ # print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
566
 
567
  # split image into non-overlapping patches
568
  self.patch_embed = PatchEmbed(
569
+ patch_size=patch_size,
570
+ in_chans=in_chans,
571
+ embed_dim=embed_dim,
572
+ norm_layer=norm_layer if self.patch_norm else None,
573
+ )
574
 
575
  # absolute position embedding
576
  if self.ape:
577
+ pretrain_img_size = to_2tuple(pretrain_img_size)
578
+ patch_size = to_2tuple(patch_size)
579
+ patches_resolution = [
580
+ pretrain_img_size[0] // patch_size[0],
581
+ pretrain_img_size[1] // patch_size[1],
582
+ ]
583
+
584
+ self.absolute_pos_embed = nn.Parameter(
585
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
586
+ )
587
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
588
 
589
  self.pos_drop = nn.Dropout(p=drop_rate)
590
 
591
  # stochastic depth
592
+ dpr = [
593
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
594
+ ] # stochastic depth decay rule
595
 
596
  # build layers
597
  self.layers = nn.ModuleList()
598
+ # prepare downsample list
599
+ downsamplelist = [PatchMerging for i in range(self.num_layers)]
600
+ downsamplelist[-1] = None
601
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
602
+ if self.dilation:
603
+ downsamplelist[-2] = None
604
+ num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
605
  for i_layer in range(self.num_layers):
606
+ layer = BasicLayer(
607
+ # dim=int(embed_dim * 2 ** i_layer),
608
+ dim=num_features[i_layer],
609
+ depth=depths[i_layer],
610
+ num_heads=num_heads[i_layer],
611
+ window_size=window_size,
612
+ mlp_ratio=mlp_ratio,
613
+ qkv_bias=qkv_bias,
614
+ qk_scale=qk_scale,
615
+ drop=drop_rate,
616
+ attn_drop=attn_drop_rate,
617
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
618
+ norm_layer=norm_layer,
619
+ # downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
620
+ downsample=downsamplelist[i_layer],
621
+ use_checkpoint=use_checkpoint,
622
+ )
623
  self.layers.append(layer)
624
 
625
+ # num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
626
+ self.num_features = num_features
627
+
628
+ # add a norm layer for each output
629
+ for i_layer in out_indices:
630
+ layer = norm_layer(num_features[i_layer])
631
+ layer_name = f"norm{i_layer}"
632
+ self.add_module(layer_name, layer)
633
+
634
+ self._freeze_stages()
635
+
636
+ def _freeze_stages(self):
637
+ if self.frozen_stages >= 0:
638
+ self.patch_embed.eval()
639
+ for param in self.patch_embed.parameters():
640
+ param.requires_grad = False
641
+
642
+ if self.frozen_stages >= 1 and self.ape:
643
+ self.absolute_pos_embed.requires_grad = False
644
+
645
+ if self.frozen_stages >= 2:
646
+ self.pos_drop.eval()
647
+ for i in range(0, self.frozen_stages - 1):
648
+ m = self.layers[i]
649
+ m.eval()
650
+ for param in m.parameters():
651
+ param.requires_grad = False
652
+
653
+ # def init_weights(self, pretrained=None):
654
+ # """Initialize the weights in backbone.
655
+ # Args:
656
+ # pretrained (str, optional): Path to pre-trained weights.
657
+ # Defaults to None.
658
+ # """
659
+
660
+ # def _init_weights(m):
661
+ # if isinstance(m, nn.Linear):
662
+ # trunc_normal_(m.weight, std=.02)
663
+ # if isinstance(m, nn.Linear) and m.bias is not None:
664
+ # nn.init.constant_(m.bias, 0)
665
+ # elif isinstance(m, nn.LayerNorm):
666
+ # nn.init.constant_(m.bias, 0)
667
+ # nn.init.constant_(m.weight, 1.0)
668
+
669
+ # if isinstance(pretrained, str):
670
+ # self.apply(_init_weights)
671
+ # logger = get_root_logger()
672
+ # load_checkpoint(self, pretrained, strict=False, logger=logger)
673
+ # elif pretrained is None:
674
+ # self.apply(_init_weights)
675
+ # else:
676
+ # raise TypeError('pretrained must be a str or None')
677
+
678
+ def forward_raw(self, x):
679
+ """Forward function."""
680
  x = self.patch_embed(x)
681
+
682
+ Wh, Ww = x.size(2), x.size(3)
683
  if self.ape:
684
+ # interpolate the position embedding to the corresponding size
685
+ absolute_pos_embed = F.interpolate(
686
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
687
+ )
688
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
689
+ else:
690
+ x = x.flatten(2).transpose(1, 2)
691
  x = self.pos_drop(x)
692
 
693
+ outs = []
694
+ for i in range(self.num_layers):
695
+ layer = self.layers[i]
696
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
697
+ # import ipdb; ipdb.set_trace()
698
+
699
+ if i in self.out_indices:
700
+ norm_layer = getattr(self, f"norm{i}")
701
+ x_out = norm_layer(x_out)
702
+
703
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
704
+ outs.append(out)
705
+ # in:
706
+ # torch.Size([2, 3, 1024, 1024])
707
+ # outs:
708
+ # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
709
+ # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
710
+ return tuple(outs)
711
+
712
+ def forward(self, tensor_list: NestedTensor):
713
+ x = tensor_list.tensors
714
+
715
+ """Forward function."""
716
+ x = self.patch_embed(x)
717
 
718
+ Wh, Ww = x.size(2), x.size(3)
719
+ if self.ape:
720
+ # interpolate the position embedding to the corresponding size
721
+ absolute_pos_embed = F.interpolate(
722
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
723
+ )
724
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
725
  else:
726
+ x = x.flatten(2).transpose(1, 2)
727
+ x = self.pos_drop(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
728
 
729
+ outs = []
730
+ for i in range(self.num_layers):
731
+ layer = self.layers[i]
732
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
733
+
734
+ if i in self.out_indices:
735
+ norm_layer = getattr(self, f"norm{i}")
736
+ x_out = norm_layer(x_out)
737
+
738
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
739
+ outs.append(out)
740
+ # in:
741
+ # torch.Size([2, 3, 1024, 1024])
742
+ # out:
743
+ # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
744
+ # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
745
+
746
+ # collect for nesttensors
747
+ outs_dict = {}
748
+ for idx, out_i in enumerate(outs):
749
+ m = tensor_list.mask
750
+ assert m is not None
751
+ mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
752
+ outs_dict[idx] = NestedTensor(out_i, mask)
753
+
754
+ return outs_dict
755
+
756
+ def train(self, mode=True):
757
+ """Convert the model into training mode while keep layers freezed."""
758
+ super(SwinTransformer, self).train(mode)
759
+ self._freeze_stages()
760
+
761
+
762
+ def build_swin_transformer(modelname, pretrain_img_size, **kw):
763
+ assert modelname in [
764
+ "swin_T_224_1k",
765
+ "swin_B_224_22k",
766
+ "swin_B_384_22k",
767
+ "swin_L_224_22k",
768
+ "swin_L_384_22k",
769
+ ]
770
+
771
+ model_para_dict = {
772
+ "swin_T_224_1k": dict(
773
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7
774
+ ),
775
+ "swin_B_224_22k": dict(
776
+ embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7
777
+ ),
778
+ "swin_B_384_22k": dict(
779
+ embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12
780
+ ),
781
+ "swin_L_224_22k": dict(
782
+ embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=7
783
+ ),
784
+ "swin_L_384_22k": dict(
785
+ embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12
786
+ ),
787
+ }
788
+ kw_cgf = model_para_dict[modelname]
789
+ kw_cgf.update(kw)
790
+ model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cgf)
791
+ return model
792
+
793
+
794
+ if __name__ == "__main__":
795
+ model = build_swin_transformer("swin_L_384_22k", 384, dilation=True)
796
+ x = torch.rand(2, 3, 1024, 1024)
797
+ y = model.forward_raw(x)
798
+ import ipdb
799
+
800
+ ipdb.set_trace()
801
+ x = torch.rand(2, 3, 384, 384)
802
+ y = model.forward_raw(x)
GroundingDINO/groundingdino/models/GroundingDINO/bertwarper.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint as checkpoint
11
+ from torch import Tensor, nn
12
+ from torchvision.ops.boxes import nms
13
+ from transformers import BertConfig, BertModel, BertPreTrainedModel
14
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
15
+
16
+
17
+ class BertModelWarper(nn.Module):
18
+ def __init__(self, bert_model):
19
+ super().__init__()
20
+ # self.bert = bert_modelc
21
+
22
+ self.config = bert_model.config
23
+ self.embeddings = bert_model.embeddings
24
+ self.encoder = bert_model.encoder
25
+ self.pooler = bert_model.pooler
26
+
27
+ self.get_extended_attention_mask = bert_model.get_extended_attention_mask
28
+ self.invert_attention_mask = bert_model.invert_attention_mask
29
+ self.get_head_mask = bert_model.get_head_mask
30
+
31
+ def forward(
32
+ self,
33
+ input_ids=None,
34
+ attention_mask=None,
35
+ token_type_ids=None,
36
+ position_ids=None,
37
+ head_mask=None,
38
+ inputs_embeds=None,
39
+ encoder_hidden_states=None,
40
+ encoder_attention_mask=None,
41
+ past_key_values=None,
42
+ use_cache=None,
43
+ output_attentions=None,
44
+ output_hidden_states=None,
45
+ return_dict=None,
46
+ ):
47
+ r"""
48
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
49
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
50
+ the model is configured as a decoder.
51
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
52
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
53
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
54
+
55
+ - 1 for tokens that are **not masked**,
56
+ - 0 for tokens that are **masked**.
57
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
58
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
59
+
60
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
61
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
62
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
63
+ use_cache (:obj:`bool`, `optional`):
64
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
65
+ decoding (see :obj:`past_key_values`).
66
+ """
67
+ output_attentions = (
68
+ output_attentions if output_attentions is not None else self.config.output_attentions
69
+ )
70
+ output_hidden_states = (
71
+ output_hidden_states
72
+ if output_hidden_states is not None
73
+ else self.config.output_hidden_states
74
+ )
75
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
76
+
77
+ if self.config.is_decoder:
78
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
79
+ else:
80
+ use_cache = False
81
+
82
+ if input_ids is not None and inputs_embeds is not None:
83
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
84
+ elif input_ids is not None:
85
+ input_shape = input_ids.size()
86
+ batch_size, seq_length = input_shape
87
+ elif inputs_embeds is not None:
88
+ input_shape = inputs_embeds.size()[:-1]
89
+ batch_size, seq_length = input_shape
90
+ else:
91
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
92
+
93
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
94
+
95
+ # past_key_values_length
96
+ past_key_values_length = (
97
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
98
+ )
99
+
100
+ if attention_mask is None:
101
+ attention_mask = torch.ones(
102
+ ((batch_size, seq_length + past_key_values_length)), device=device
103
+ )
104
+ if token_type_ids is None:
105
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
106
+
107
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
108
+ # ourselves in which case we just need to make it broadcastable to all heads.
109
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
110
+ attention_mask, input_shape, device
111
+ )
112
+
113
+ # If a 2D or 3D attention mask is provided for the cross-attention
114
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
115
+ if self.config.is_decoder and encoder_hidden_states is not None:
116
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
117
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
118
+ if encoder_attention_mask is None:
119
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
120
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
121
+ else:
122
+ encoder_extended_attention_mask = None
123
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
124
+ # import ipdb; ipdb.set_trace()
125
+
126
+ # Prepare head mask if needed
127
+ # 1.0 in head_mask indicate we keep the head
128
+ # attention_probs has shape bsz x n_heads x N x N
129
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
130
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
131
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
132
+
133
+ embedding_output = self.embeddings(
134
+ input_ids=input_ids,
135
+ position_ids=position_ids,
136
+ token_type_ids=token_type_ids,
137
+ inputs_embeds=inputs_embeds,
138
+ past_key_values_length=past_key_values_length,
139
+ )
140
+
141
+ encoder_outputs = self.encoder(
142
+ embedding_output,
143
+ attention_mask=extended_attention_mask,
144
+ head_mask=head_mask,
145
+ encoder_hidden_states=encoder_hidden_states,
146
+ encoder_attention_mask=encoder_extended_attention_mask,
147
+ past_key_values=past_key_values,
148
+ use_cache=use_cache,
149
+ output_attentions=output_attentions,
150
+ output_hidden_states=output_hidden_states,
151
+ return_dict=return_dict,
152
+ )
153
+ sequence_output = encoder_outputs[0]
154
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
155
+
156
+ if not return_dict:
157
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
158
+
159
+ return BaseModelOutputWithPoolingAndCrossAttentions(
160
+ last_hidden_state=sequence_output,
161
+ pooler_output=pooled_output,
162
+ past_key_values=encoder_outputs.past_key_values,
163
+ hidden_states=encoder_outputs.hidden_states,
164
+ attentions=encoder_outputs.attentions,
165
+ cross_attentions=encoder_outputs.cross_attentions,
166
+ )
167
+
168
+
169
+ class TextEncoderShell(nn.Module):
170
+ def __init__(self, text_encoder):
171
+ super().__init__()
172
+ self.text_encoder = text_encoder
173
+ self.config = self.text_encoder.config
174
+
175
+ def forward(self, **kw):
176
+ # feed into text encoder
177
+ return self.text_encoder(**kw)
178
+
179
+
180
+ def generate_masks_with_special_tokens(tokenized, special_tokens_list, tokenizer):
181
+ """Generate attention mask between each pair of special tokens
182
+ Args:
183
+ input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
184
+ special_tokens_mask (list): special tokens mask.
185
+ Returns:
186
+ torch.Tensor: attention mask between each special tokens.
187
+ """
188
+ input_ids = tokenized["input_ids"]
189
+ bs, num_token = input_ids.shape
190
+ # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
191
+ special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
192
+ for special_token in special_tokens_list:
193
+ special_tokens_mask |= input_ids == special_token
194
+
195
+ # idxs: each row is a list of indices of special tokens
196
+ idxs = torch.nonzero(special_tokens_mask)
197
+
198
+ # generate attention mask and positional ids
199
+ attention_mask = (
200
+ torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
201
+ )
202
+ position_ids = torch.zeros((bs, num_token), device=input_ids.device)
203
+ previous_col = 0
204
+ for i in range(idxs.shape[0]):
205
+ row, col = idxs[i]
206
+ if (col == 0) or (col == num_token - 1):
207
+ attention_mask[row, col, col] = True
208
+ position_ids[row, col] = 0
209
+ else:
210
+ attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
211
+ position_ids[row, previous_col + 1 : col + 1] = torch.arange(
212
+ 0, col - previous_col, device=input_ids.device
213
+ )
214
+
215
+ previous_col = col
216
+
217
+ # # padding mask
218
+ # padding_mask = tokenized['attention_mask']
219
+ # attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
220
+
221
+ return attention_mask, position_ids.to(torch.long)
222
+
223
+
224
+ def generate_masks_with_special_tokens_and_transfer_map(tokenized, special_tokens_list, tokenizer):
225
+ """Generate attention mask between each pair of special tokens
226
+ Args:
227
+ input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
228
+ special_tokens_mask (list): special tokens mask.
229
+ Returns:
230
+ torch.Tensor: attention mask between each special tokens.
231
+ """
232
+ input_ids = tokenized["input_ids"]
233
+ bs, num_token = input_ids.shape
234
+ # special_tokens_mask: bs, num_token. 1 for special tokens. 0 for normal tokens
235
+ special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool()
236
+ for special_token in special_tokens_list:
237
+ special_tokens_mask |= input_ids == special_token
238
+
239
+ # idxs: each row is a list of indices of special tokens
240
+ idxs = torch.nonzero(special_tokens_mask)
241
+
242
+ # generate attention mask and positional ids
243
+ attention_mask = (
244
+ torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(bs, 1, 1)
245
+ )
246
+ position_ids = torch.zeros((bs, num_token), device=input_ids.device)
247
+ cate_to_token_mask_list = [[] for _ in range(bs)]
248
+ previous_col = 0
249
+ for i in range(idxs.shape[0]):
250
+ row, col = idxs[i]
251
+ if (col == 0) or (col == num_token - 1):
252
+ attention_mask[row, col, col] = True
253
+ position_ids[row, col] = 0
254
+ else:
255
+ attention_mask[row, previous_col + 1 : col + 1, previous_col + 1 : col + 1] = True
256
+ position_ids[row, previous_col + 1 : col + 1] = torch.arange(
257
+ 0, col - previous_col, device=input_ids.device
258
+ )
259
+ c2t_maski = torch.zeros((num_token), device=input_ids.device).bool()
260
+ c2t_maski[previous_col + 1 : col] = True
261
+ cate_to_token_mask_list[row].append(c2t_maski)
262
+ previous_col = col
263
+
264
+ cate_to_token_mask_list = [
265
+ torch.stack(cate_to_token_mask_listi, dim=0)
266
+ for cate_to_token_mask_listi in cate_to_token_mask_list
267
+ ]
268
+
269
+ # # padding mask
270
+ # padding_mask = tokenized['attention_mask']
271
+ # attention_mask = attention_mask & padding_mask.unsqueeze(1).bool() & padding_mask.unsqueeze(2).bool()
272
+
273
+ return attention_mask, position_ids.to(torch.long), cate_to_token_mask_list
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+
13
+ #include "ms_deform_attn_cpu.h"
14
+
15
+ #ifdef WITH_CUDA
16
+ #include "ms_deform_attn_cuda.h"
17
+ #endif
18
+
19
+ namespace groundingdino {
20
+
21
+ at::Tensor
22
+ ms_deform_attn_forward(
23
+ const at::Tensor &value,
24
+ const at::Tensor &spatial_shapes,
25
+ const at::Tensor &level_start_index,
26
+ const at::Tensor &sampling_loc,
27
+ const at::Tensor &attn_weight,
28
+ const int im2col_step)
29
+ {
30
+ if (value.type().is_cuda())
31
+ {
32
+ #ifdef WITH_CUDA
33
+ return ms_deform_attn_cuda_forward(
34
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
35
+ #else
36
+ AT_ERROR("Not compiled with GPU support");
37
+ #endif
38
+ }
39
+ AT_ERROR("Not implemented on the CPU");
40
+ }
41
+
42
+ std::vector<at::Tensor>
43
+ ms_deform_attn_backward(
44
+ const at::Tensor &value,
45
+ const at::Tensor &spatial_shapes,
46
+ const at::Tensor &level_start_index,
47
+ const at::Tensor &sampling_loc,
48
+ const at::Tensor &attn_weight,
49
+ const at::Tensor &grad_output,
50
+ const int im2col_step)
51
+ {
52
+ if (value.type().is_cuda())
53
+ {
54
+ #ifdef WITH_CUDA
55
+ return ms_deform_attn_cuda_backward(
56
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
57
+ #else
58
+ AT_ERROR("Not compiled with GPU support");
59
+ #endif
60
+ }
61
+ AT_ERROR("Not implemented on the CPU");
62
+ }
63
+
64
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+
13
+ #include <ATen/ATen.h>
14
+ #include <ATen/cuda/CUDAContext.h>
15
+
16
+ namespace groundingdino {
17
+
18
+ at::Tensor
19
+ ms_deform_attn_cpu_forward(
20
+ const at::Tensor &value,
21
+ const at::Tensor &spatial_shapes,
22
+ const at::Tensor &level_start_index,
23
+ const at::Tensor &sampling_loc,
24
+ const at::Tensor &attn_weight,
25
+ const int im2col_step)
26
+ {
27
+ AT_ERROR("Not implement on cpu");
28
+ }
29
+
30
+ std::vector<at::Tensor>
31
+ ms_deform_attn_cpu_backward(
32
+ const at::Tensor &value,
33
+ const at::Tensor &spatial_shapes,
34
+ const at::Tensor &level_start_index,
35
+ const at::Tensor &sampling_loc,
36
+ const at::Tensor &attn_weight,
37
+ const at::Tensor &grad_output,
38
+ const int im2col_step)
39
+ {
40
+ AT_ERROR("Not implement on cpu");
41
+ }
42
+
43
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ namespace groundingdino {
15
+
16
+ at::Tensor
17
+ ms_deform_attn_cpu_forward(
18
+ const at::Tensor &value,
19
+ const at::Tensor &spatial_shapes,
20
+ const at::Tensor &level_start_index,
21
+ const at::Tensor &sampling_loc,
22
+ const at::Tensor &attn_weight,
23
+ const int im2col_step);
24
+
25
+ std::vector<at::Tensor>
26
+ ms_deform_attn_cpu_backward(
27
+ const at::Tensor &value,
28
+ const at::Tensor &spatial_shapes,
29
+ const at::Tensor &level_start_index,
30
+ const at::Tensor &sampling_loc,
31
+ const at::Tensor &attn_weight,
32
+ const at::Tensor &grad_output,
33
+ const int im2col_step);
34
+
35
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+ #include "ms_deform_im2col_cuda.cuh"
13
+
14
+ #include <ATen/ATen.h>
15
+ #include <ATen/cuda/CUDAContext.h>
16
+ #include <cuda.h>
17
+ #include <cuda_runtime.h>
18
+
19
+ namespace groundingdino {
20
+
21
+ at::Tensor ms_deform_attn_cuda_forward(
22
+ const at::Tensor &value,
23
+ const at::Tensor &spatial_shapes,
24
+ const at::Tensor &level_start_index,
25
+ const at::Tensor &sampling_loc,
26
+ const at::Tensor &attn_weight,
27
+ const int im2col_step)
28
+ {
29
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
30
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
31
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
32
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
33
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
34
+
35
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
36
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
37
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
38
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
39
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
40
+
41
+ const int batch = value.size(0);
42
+ const int spatial_size = value.size(1);
43
+ const int num_heads = value.size(2);
44
+ const int channels = value.size(3);
45
+
46
+ const int num_levels = spatial_shapes.size(0);
47
+
48
+ const int num_query = sampling_loc.size(1);
49
+ const int num_point = sampling_loc.size(4);
50
+
51
+ const int im2col_step_ = std::min(batch, im2col_step);
52
+
53
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
54
+
55
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
56
+
57
+ const int batch_n = im2col_step_;
58
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
59
+ auto per_value_size = spatial_size * num_heads * channels;
60
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
61
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
62
+ for (int n = 0; n < batch/im2col_step_; ++n)
63
+ {
64
+ auto columns = output_n.select(0, n);
65
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
66
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
67
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
68
+ spatial_shapes.data<int64_t>(),
69
+ level_start_index.data<int64_t>(),
70
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
71
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
72
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
73
+ columns.data<scalar_t>());
74
+
75
+ }));
76
+ }
77
+
78
+ output = output.view({batch, num_query, num_heads*channels});
79
+
80
+ return output;
81
+ }
82
+
83
+
84
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
85
+ const at::Tensor &value,
86
+ const at::Tensor &spatial_shapes,
87
+ const at::Tensor &level_start_index,
88
+ const at::Tensor &sampling_loc,
89
+ const at::Tensor &attn_weight,
90
+ const at::Tensor &grad_output,
91
+ const int im2col_step)
92
+ {
93
+
94
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
95
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
96
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
97
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
98
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
99
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
100
+
101
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
102
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
103
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
104
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
105
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
106
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
107
+
108
+ const int batch = value.size(0);
109
+ const int spatial_size = value.size(1);
110
+ const int num_heads = value.size(2);
111
+ const int channels = value.size(3);
112
+
113
+ const int num_levels = spatial_shapes.size(0);
114
+
115
+ const int num_query = sampling_loc.size(1);
116
+ const int num_point = sampling_loc.size(4);
117
+
118
+ const int im2col_step_ = std::min(batch, im2col_step);
119
+
120
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
121
+
122
+ auto grad_value = at::zeros_like(value);
123
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
124
+ auto grad_attn_weight = at::zeros_like(attn_weight);
125
+
126
+ const int batch_n = im2col_step_;
127
+ auto per_value_size = spatial_size * num_heads * channels;
128
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
129
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
130
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
131
+
132
+ for (int n = 0; n < batch/im2col_step_; ++n)
133
+ {
134
+ auto grad_output_g = grad_output_n.select(0, n);
135
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
136
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
137
+ grad_output_g.data<scalar_t>(),
138
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
139
+ spatial_shapes.data<int64_t>(),
140
+ level_start_index.data<int64_t>(),
141
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
142
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
143
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
144
+ grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
145
+ grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
146
+ grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
147
+
148
+ }));
149
+ }
150
+
151
+ return {
152
+ grad_value, grad_sampling_loc, grad_attn_weight
153
+ };
154
+ }
155
+
156
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ namespace groundingdino {
15
+
16
+ at::Tensor ms_deform_attn_cuda_forward(
17
+ const at::Tensor &value,
18
+ const at::Tensor &spatial_shapes,
19
+ const at::Tensor &level_start_index,
20
+ const at::Tensor &sampling_loc,
21
+ const at::Tensor &attn_weight,
22
+ const int im2col_step);
23
+
24
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
25
+ const at::Tensor &value,
26
+ const at::Tensor &spatial_shapes,
27
+ const at::Tensor &level_start_index,
28
+ const at::Tensor &sampling_loc,
29
+ const at::Tensor &attn_weight,
30
+ const at::Tensor &grad_output,
31
+ const int im2col_step);
32
+
33
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_im2col_cuda.cuh ADDED
@@ -0,0 +1,1327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************
7
+ * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
8
+ * Copyright (c) 2018 Microsoft
9
+ **************************************************************************
10
+ */
11
+
12
+ #include <cstdio>
13
+ #include <algorithm>
14
+ #include <cstring>
15
+
16
+ #include <ATen/ATen.h>
17
+ #include <ATen/cuda/CUDAContext.h>
18
+
19
+ #include <THC/THCAtomics.cuh>
20
+
21
+ #define CUDA_KERNEL_LOOP(i, n) \
22
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
23
+ i < (n); \
24
+ i += blockDim.x * gridDim.x)
25
+
26
+ const int CUDA_NUM_THREADS = 1024;
27
+ inline int GET_BLOCKS(const int N, const int num_threads)
28
+ {
29
+ return (N + num_threads - 1) / num_threads;
30
+ }
31
+
32
+
33
+ template <typename scalar_t>
34
+ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
35
+ const int &height, const int &width, const int &nheads, const int &channels,
36
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
37
+ {
38
+ const int h_low = floor(h);
39
+ const int w_low = floor(w);
40
+ const int h_high = h_low + 1;
41
+ const int w_high = w_low + 1;
42
+
43
+ const scalar_t lh = h - h_low;
44
+ const scalar_t lw = w - w_low;
45
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
46
+
47
+ const int w_stride = nheads * channels;
48
+ const int h_stride = width * w_stride;
49
+ const int h_low_ptr_offset = h_low * h_stride;
50
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
51
+ const int w_low_ptr_offset = w_low * w_stride;
52
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
53
+ const int base_ptr = m * channels + c;
54
+
55
+ scalar_t v1 = 0;
56
+ if (h_low >= 0 && w_low >= 0)
57
+ {
58
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
59
+ v1 = bottom_data[ptr1];
60
+ }
61
+ scalar_t v2 = 0;
62
+ if (h_low >= 0 && w_high <= width - 1)
63
+ {
64
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
65
+ v2 = bottom_data[ptr2];
66
+ }
67
+ scalar_t v3 = 0;
68
+ if (h_high <= height - 1 && w_low >= 0)
69
+ {
70
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
71
+ v3 = bottom_data[ptr3];
72
+ }
73
+ scalar_t v4 = 0;
74
+ if (h_high <= height - 1 && w_high <= width - 1)
75
+ {
76
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
77
+ v4 = bottom_data[ptr4];
78
+ }
79
+
80
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
81
+
82
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
83
+ return val;
84
+ }
85
+
86
+
87
+ template <typename scalar_t>
88
+ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
89
+ const int &height, const int &width, const int &nheads, const int &channels,
90
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
91
+ const scalar_t &top_grad,
92
+ const scalar_t &attn_weight,
93
+ scalar_t* &grad_value,
94
+ scalar_t* grad_sampling_loc,
95
+ scalar_t* grad_attn_weight)
96
+ {
97
+ const int h_low = floor(h);
98
+ const int w_low = floor(w);
99
+ const int h_high = h_low + 1;
100
+ const int w_high = w_low + 1;
101
+
102
+ const scalar_t lh = h - h_low;
103
+ const scalar_t lw = w - w_low;
104
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
105
+
106
+ const int w_stride = nheads * channels;
107
+ const int h_stride = width * w_stride;
108
+ const int h_low_ptr_offset = h_low * h_stride;
109
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
110
+ const int w_low_ptr_offset = w_low * w_stride;
111
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
112
+ const int base_ptr = m * channels + c;
113
+
114
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
115
+ const scalar_t top_grad_value = top_grad * attn_weight;
116
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
117
+
118
+ scalar_t v1 = 0;
119
+ if (h_low >= 0 && w_low >= 0)
120
+ {
121
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
122
+ v1 = bottom_data[ptr1];
123
+ grad_h_weight -= hw * v1;
124
+ grad_w_weight -= hh * v1;
125
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
126
+ }
127
+ scalar_t v2 = 0;
128
+ if (h_low >= 0 && w_high <= width - 1)
129
+ {
130
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
131
+ v2 = bottom_data[ptr2];
132
+ grad_h_weight -= lw * v2;
133
+ grad_w_weight += hh * v2;
134
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
135
+ }
136
+ scalar_t v3 = 0;
137
+ if (h_high <= height - 1 && w_low >= 0)
138
+ {
139
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
140
+ v3 = bottom_data[ptr3];
141
+ grad_h_weight += hw * v3;
142
+ grad_w_weight -= lh * v3;
143
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
144
+ }
145
+ scalar_t v4 = 0;
146
+ if (h_high <= height - 1 && w_high <= width - 1)
147
+ {
148
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
149
+ v4 = bottom_data[ptr4];
150
+ grad_h_weight += lw * v4;
151
+ grad_w_weight += lh * v4;
152
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
153
+ }
154
+
155
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
156
+ *grad_attn_weight = top_grad * val;
157
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
158
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
159
+ }
160
+
161
+
162
+ template <typename scalar_t>
163
+ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
164
+ const int &height, const int &width, const int &nheads, const int &channels,
165
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
166
+ const scalar_t &top_grad,
167
+ const scalar_t &attn_weight,
168
+ scalar_t* &grad_value,
169
+ scalar_t* grad_sampling_loc,
170
+ scalar_t* grad_attn_weight)
171
+ {
172
+ const int h_low = floor(h);
173
+ const int w_low = floor(w);
174
+ const int h_high = h_low + 1;
175
+ const int w_high = w_low + 1;
176
+
177
+ const scalar_t lh = h - h_low;
178
+ const scalar_t lw = w - w_low;
179
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
180
+
181
+ const int w_stride = nheads * channels;
182
+ const int h_stride = width * w_stride;
183
+ const int h_low_ptr_offset = h_low * h_stride;
184
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
185
+ const int w_low_ptr_offset = w_low * w_stride;
186
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
187
+ const int base_ptr = m * channels + c;
188
+
189
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
190
+ const scalar_t top_grad_value = top_grad * attn_weight;
191
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
192
+
193
+ scalar_t v1 = 0;
194
+ if (h_low >= 0 && w_low >= 0)
195
+ {
196
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
197
+ v1 = bottom_data[ptr1];
198
+ grad_h_weight -= hw * v1;
199
+ grad_w_weight -= hh * v1;
200
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
201
+ }
202
+ scalar_t v2 = 0;
203
+ if (h_low >= 0 && w_high <= width - 1)
204
+ {
205
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
206
+ v2 = bottom_data[ptr2];
207
+ grad_h_weight -= lw * v2;
208
+ grad_w_weight += hh * v2;
209
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
210
+ }
211
+ scalar_t v3 = 0;
212
+ if (h_high <= height - 1 && w_low >= 0)
213
+ {
214
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
215
+ v3 = bottom_data[ptr3];
216
+ grad_h_weight += hw * v3;
217
+ grad_w_weight -= lh * v3;
218
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
219
+ }
220
+ scalar_t v4 = 0;
221
+ if (h_high <= height - 1 && w_high <= width - 1)
222
+ {
223
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
224
+ v4 = bottom_data[ptr4];
225
+ grad_h_weight += lw * v4;
226
+ grad_w_weight += lh * v4;
227
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
228
+ }
229
+
230
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
231
+ atomicAdd(grad_attn_weight, top_grad * val);
232
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
233
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
234
+ }
235
+
236
+
237
+ template <typename scalar_t>
238
+ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
239
+ const scalar_t *data_value,
240
+ const int64_t *data_spatial_shapes,
241
+ const int64_t *data_level_start_index,
242
+ const scalar_t *data_sampling_loc,
243
+ const scalar_t *data_attn_weight,
244
+ const int batch_size,
245
+ const int spatial_size,
246
+ const int num_heads,
247
+ const int channels,
248
+ const int num_levels,
249
+ const int num_query,
250
+ const int num_point,
251
+ scalar_t *data_col)
252
+ {
253
+ CUDA_KERNEL_LOOP(index, n)
254
+ {
255
+ int _temp = index;
256
+ const int c_col = _temp % channels;
257
+ _temp /= channels;
258
+ const int sampling_index = _temp;
259
+ const int m_col = _temp % num_heads;
260
+ _temp /= num_heads;
261
+ const int q_col = _temp % num_query;
262
+ _temp /= num_query;
263
+ const int b_col = _temp;
264
+
265
+ scalar_t *data_col_ptr = data_col + index;
266
+ int data_weight_ptr = sampling_index * num_levels * num_point;
267
+ int data_loc_w_ptr = data_weight_ptr << 1;
268
+ const int qid_stride = num_heads * channels;
269
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
270
+ scalar_t col = 0;
271
+
272
+ for (int l_col=0; l_col < num_levels; ++l_col)
273
+ {
274
+ const int level_start_id = data_level_start_index[l_col];
275
+ const int spatial_h_ptr = l_col << 1;
276
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
277
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
278
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
279
+ for (int p_col=0; p_col < num_point; ++p_col)
280
+ {
281
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
282
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
283
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
284
+
285
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
286
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
287
+
288
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
289
+ {
290
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
291
+ }
292
+
293
+ data_weight_ptr += 1;
294
+ data_loc_w_ptr += 2;
295
+ }
296
+ }
297
+ *data_col_ptr = col;
298
+ }
299
+ }
300
+
301
+ template <typename scalar_t, unsigned int blockSize>
302
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
303
+ const scalar_t *grad_col,
304
+ const scalar_t *data_value,
305
+ const int64_t *data_spatial_shapes,
306
+ const int64_t *data_level_start_index,
307
+ const scalar_t *data_sampling_loc,
308
+ const scalar_t *data_attn_weight,
309
+ const int batch_size,
310
+ const int spatial_size,
311
+ const int num_heads,
312
+ const int channels,
313
+ const int num_levels,
314
+ const int num_query,
315
+ const int num_point,
316
+ scalar_t *grad_value,
317
+ scalar_t *grad_sampling_loc,
318
+ scalar_t *grad_attn_weight)
319
+ {
320
+ CUDA_KERNEL_LOOP(index, n)
321
+ {
322
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
323
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
324
+ unsigned int tid = threadIdx.x;
325
+ int _temp = index;
326
+ const int c_col = _temp % channels;
327
+ _temp /= channels;
328
+ const int sampling_index = _temp;
329
+ const int m_col = _temp % num_heads;
330
+ _temp /= num_heads;
331
+ const int q_col = _temp % num_query;
332
+ _temp /= num_query;
333
+ const int b_col = _temp;
334
+
335
+ const scalar_t top_grad = grad_col[index];
336
+
337
+ int data_weight_ptr = sampling_index * num_levels * num_point;
338
+ int data_loc_w_ptr = data_weight_ptr << 1;
339
+ const int grad_sampling_ptr = data_weight_ptr;
340
+ grad_sampling_loc += grad_sampling_ptr << 1;
341
+ grad_attn_weight += grad_sampling_ptr;
342
+ const int grad_weight_stride = 1;
343
+ const int grad_loc_stride = 2;
344
+ const int qid_stride = num_heads * channels;
345
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
346
+
347
+ for (int l_col=0; l_col < num_levels; ++l_col)
348
+ {
349
+ const int level_start_id = data_level_start_index[l_col];
350
+ const int spatial_h_ptr = l_col << 1;
351
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
352
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
353
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
354
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
355
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
356
+
357
+ for (int p_col=0; p_col < num_point; ++p_col)
358
+ {
359
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
360
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
361
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
362
+
363
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
364
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
365
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
366
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
367
+ *(cache_grad_attn_weight+threadIdx.x)=0;
368
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
369
+ {
370
+ ms_deform_attn_col2im_bilinear(
371
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
372
+ top_grad, weight, grad_value_ptr,
373
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
374
+ }
375
+
376
+ __syncthreads();
377
+ if (tid == 0)
378
+ {
379
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
380
+ int sid=2;
381
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
382
+ {
383
+ _grad_w += cache_grad_sampling_loc[sid];
384
+ _grad_h += cache_grad_sampling_loc[sid + 1];
385
+ _grad_a += cache_grad_attn_weight[tid];
386
+ sid += 2;
387
+ }
388
+
389
+
390
+ *grad_sampling_loc = _grad_w;
391
+ *(grad_sampling_loc + 1) = _grad_h;
392
+ *grad_attn_weight = _grad_a;
393
+ }
394
+ __syncthreads();
395
+
396
+ data_weight_ptr += 1;
397
+ data_loc_w_ptr += 2;
398
+ grad_attn_weight += grad_weight_stride;
399
+ grad_sampling_loc += grad_loc_stride;
400
+ }
401
+ }
402
+ }
403
+ }
404
+
405
+
406
+ template <typename scalar_t, unsigned int blockSize>
407
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
408
+ const scalar_t *grad_col,
409
+ const scalar_t *data_value,
410
+ const int64_t *data_spatial_shapes,
411
+ const int64_t *data_level_start_index,
412
+ const scalar_t *data_sampling_loc,
413
+ const scalar_t *data_attn_weight,
414
+ const int batch_size,
415
+ const int spatial_size,
416
+ const int num_heads,
417
+ const int channels,
418
+ const int num_levels,
419
+ const int num_query,
420
+ const int num_point,
421
+ scalar_t *grad_value,
422
+ scalar_t *grad_sampling_loc,
423
+ scalar_t *grad_attn_weight)
424
+ {
425
+ CUDA_KERNEL_LOOP(index, n)
426
+ {
427
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
428
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
429
+ unsigned int tid = threadIdx.x;
430
+ int _temp = index;
431
+ const int c_col = _temp % channels;
432
+ _temp /= channels;
433
+ const int sampling_index = _temp;
434
+ const int m_col = _temp % num_heads;
435
+ _temp /= num_heads;
436
+ const int q_col = _temp % num_query;
437
+ _temp /= num_query;
438
+ const int b_col = _temp;
439
+
440
+ const scalar_t top_grad = grad_col[index];
441
+
442
+ int data_weight_ptr = sampling_index * num_levels * num_point;
443
+ int data_loc_w_ptr = data_weight_ptr << 1;
444
+ const int grad_sampling_ptr = data_weight_ptr;
445
+ grad_sampling_loc += grad_sampling_ptr << 1;
446
+ grad_attn_weight += grad_sampling_ptr;
447
+ const int grad_weight_stride = 1;
448
+ const int grad_loc_stride = 2;
449
+ const int qid_stride = num_heads * channels;
450
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
451
+
452
+ for (int l_col=0; l_col < num_levels; ++l_col)
453
+ {
454
+ const int level_start_id = data_level_start_index[l_col];
455
+ const int spatial_h_ptr = l_col << 1;
456
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
457
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
458
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
459
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
460
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
461
+
462
+ for (int p_col=0; p_col < num_point; ++p_col)
463
+ {
464
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
465
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
466
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
467
+
468
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
469
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
470
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
471
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
472
+ *(cache_grad_attn_weight+threadIdx.x)=0;
473
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
474
+ {
475
+ ms_deform_attn_col2im_bilinear(
476
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
477
+ top_grad, weight, grad_value_ptr,
478
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
479
+ }
480
+
481
+ __syncthreads();
482
+
483
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
484
+ {
485
+ if (tid < s) {
486
+ const unsigned int xid1 = tid << 1;
487
+ const unsigned int xid2 = (tid + s) << 1;
488
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
489
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
490
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
491
+ }
492
+ __syncthreads();
493
+ }
494
+
495
+ if (tid == 0)
496
+ {
497
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
498
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
499
+ *grad_attn_weight = cache_grad_attn_weight[0];
500
+ }
501
+ __syncthreads();
502
+
503
+ data_weight_ptr += 1;
504
+ data_loc_w_ptr += 2;
505
+ grad_attn_weight += grad_weight_stride;
506
+ grad_sampling_loc += grad_loc_stride;
507
+ }
508
+ }
509
+ }
510
+ }
511
+
512
+
513
+ template <typename scalar_t>
514
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
515
+ const scalar_t *grad_col,
516
+ const scalar_t *data_value,
517
+ const int64_t *data_spatial_shapes,
518
+ const int64_t *data_level_start_index,
519
+ const scalar_t *data_sampling_loc,
520
+ const scalar_t *data_attn_weight,
521
+ const int batch_size,
522
+ const int spatial_size,
523
+ const int num_heads,
524
+ const int channels,
525
+ const int num_levels,
526
+ const int num_query,
527
+ const int num_point,
528
+ scalar_t *grad_value,
529
+ scalar_t *grad_sampling_loc,
530
+ scalar_t *grad_attn_weight)
531
+ {
532
+ CUDA_KERNEL_LOOP(index, n)
533
+ {
534
+ extern __shared__ int _s[];
535
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
536
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
537
+ unsigned int tid = threadIdx.x;
538
+ int _temp = index;
539
+ const int c_col = _temp % channels;
540
+ _temp /= channels;
541
+ const int sampling_index = _temp;
542
+ const int m_col = _temp % num_heads;
543
+ _temp /= num_heads;
544
+ const int q_col = _temp % num_query;
545
+ _temp /= num_query;
546
+ const int b_col = _temp;
547
+
548
+ const scalar_t top_grad = grad_col[index];
549
+
550
+ int data_weight_ptr = sampling_index * num_levels * num_point;
551
+ int data_loc_w_ptr = data_weight_ptr << 1;
552
+ const int grad_sampling_ptr = data_weight_ptr;
553
+ grad_sampling_loc += grad_sampling_ptr << 1;
554
+ grad_attn_weight += grad_sampling_ptr;
555
+ const int grad_weight_stride = 1;
556
+ const int grad_loc_stride = 2;
557
+ const int qid_stride = num_heads * channels;
558
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
559
+
560
+ for (int l_col=0; l_col < num_levels; ++l_col)
561
+ {
562
+ const int level_start_id = data_level_start_index[l_col];
563
+ const int spatial_h_ptr = l_col << 1;
564
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
565
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
566
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
567
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
568
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
569
+
570
+ for (int p_col=0; p_col < num_point; ++p_col)
571
+ {
572
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
573
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
574
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
575
+
576
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
577
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
578
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
579
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
580
+ *(cache_grad_attn_weight+threadIdx.x)=0;
581
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
582
+ {
583
+ ms_deform_attn_col2im_bilinear(
584
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
585
+ top_grad, weight, grad_value_ptr,
586
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
587
+ }
588
+
589
+ __syncthreads();
590
+ if (tid == 0)
591
+ {
592
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
593
+ int sid=2;
594
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
595
+ {
596
+ _grad_w += cache_grad_sampling_loc[sid];
597
+ _grad_h += cache_grad_sampling_loc[sid + 1];
598
+ _grad_a += cache_grad_attn_weight[tid];
599
+ sid += 2;
600
+ }
601
+
602
+
603
+ *grad_sampling_loc = _grad_w;
604
+ *(grad_sampling_loc + 1) = _grad_h;
605
+ *grad_attn_weight = _grad_a;
606
+ }
607
+ __syncthreads();
608
+
609
+ data_weight_ptr += 1;
610
+ data_loc_w_ptr += 2;
611
+ grad_attn_weight += grad_weight_stride;
612
+ grad_sampling_loc += grad_loc_stride;
613
+ }
614
+ }
615
+ }
616
+ }
617
+
618
+ template <typename scalar_t>
619
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
620
+ const scalar_t *grad_col,
621
+ const scalar_t *data_value,
622
+ const int64_t *data_spatial_shapes,
623
+ const int64_t *data_level_start_index,
624
+ const scalar_t *data_sampling_loc,
625
+ const scalar_t *data_attn_weight,
626
+ const int batch_size,
627
+ const int spatial_size,
628
+ const int num_heads,
629
+ const int channels,
630
+ const int num_levels,
631
+ const int num_query,
632
+ const int num_point,
633
+ scalar_t *grad_value,
634
+ scalar_t *grad_sampling_loc,
635
+ scalar_t *grad_attn_weight)
636
+ {
637
+ CUDA_KERNEL_LOOP(index, n)
638
+ {
639
+ extern __shared__ int _s[];
640
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
641
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
642
+ unsigned int tid = threadIdx.x;
643
+ int _temp = index;
644
+ const int c_col = _temp % channels;
645
+ _temp /= channels;
646
+ const int sampling_index = _temp;
647
+ const int m_col = _temp % num_heads;
648
+ _temp /= num_heads;
649
+ const int q_col = _temp % num_query;
650
+ _temp /= num_query;
651
+ const int b_col = _temp;
652
+
653
+ const scalar_t top_grad = grad_col[index];
654
+
655
+ int data_weight_ptr = sampling_index * num_levels * num_point;
656
+ int data_loc_w_ptr = data_weight_ptr << 1;
657
+ const int grad_sampling_ptr = data_weight_ptr;
658
+ grad_sampling_loc += grad_sampling_ptr << 1;
659
+ grad_attn_weight += grad_sampling_ptr;
660
+ const int grad_weight_stride = 1;
661
+ const int grad_loc_stride = 2;
662
+ const int qid_stride = num_heads * channels;
663
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
664
+
665
+ for (int l_col=0; l_col < num_levels; ++l_col)
666
+ {
667
+ const int level_start_id = data_level_start_index[l_col];
668
+ const int spatial_h_ptr = l_col << 1;
669
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
670
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
671
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
672
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
673
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
674
+
675
+ for (int p_col=0; p_col < num_point; ++p_col)
676
+ {
677
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
678
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
679
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
680
+
681
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
682
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
683
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
684
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
685
+ *(cache_grad_attn_weight+threadIdx.x)=0;
686
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
687
+ {
688
+ ms_deform_attn_col2im_bilinear(
689
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
690
+ top_grad, weight, grad_value_ptr,
691
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
692
+ }
693
+
694
+ __syncthreads();
695
+
696
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
697
+ {
698
+ if (tid < s) {
699
+ const unsigned int xid1 = tid << 1;
700
+ const unsigned int xid2 = (tid + s) << 1;
701
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
702
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
703
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
704
+ if (tid + (s << 1) < spre)
705
+ {
706
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
707
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
708
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
709
+ }
710
+ }
711
+ __syncthreads();
712
+ }
713
+
714
+ if (tid == 0)
715
+ {
716
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
717
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
718
+ *grad_attn_weight = cache_grad_attn_weight[0];
719
+ }
720
+ __syncthreads();
721
+
722
+ data_weight_ptr += 1;
723
+ data_loc_w_ptr += 2;
724
+ grad_attn_weight += grad_weight_stride;
725
+ grad_sampling_loc += grad_loc_stride;
726
+ }
727
+ }
728
+ }
729
+ }
730
+
731
+ template <typename scalar_t>
732
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
733
+ const scalar_t *grad_col,
734
+ const scalar_t *data_value,
735
+ const int64_t *data_spatial_shapes,
736
+ const int64_t *data_level_start_index,
737
+ const scalar_t *data_sampling_loc,
738
+ const scalar_t *data_attn_weight,
739
+ const int batch_size,
740
+ const int spatial_size,
741
+ const int num_heads,
742
+ const int channels,
743
+ const int num_levels,
744
+ const int num_query,
745
+ const int num_point,
746
+ scalar_t *grad_value,
747
+ scalar_t *grad_sampling_loc,
748
+ scalar_t *grad_attn_weight)
749
+ {
750
+ CUDA_KERNEL_LOOP(index, n)
751
+ {
752
+ extern __shared__ int _s[];
753
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
754
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
755
+ unsigned int tid = threadIdx.x;
756
+ int _temp = index;
757
+ const int c_col = _temp % channels;
758
+ _temp /= channels;
759
+ const int sampling_index = _temp;
760
+ const int m_col = _temp % num_heads;
761
+ _temp /= num_heads;
762
+ const int q_col = _temp % num_query;
763
+ _temp /= num_query;
764
+ const int b_col = _temp;
765
+
766
+ const scalar_t top_grad = grad_col[index];
767
+
768
+ int data_weight_ptr = sampling_index * num_levels * num_point;
769
+ int data_loc_w_ptr = data_weight_ptr << 1;
770
+ const int grad_sampling_ptr = data_weight_ptr;
771
+ grad_sampling_loc += grad_sampling_ptr << 1;
772
+ grad_attn_weight += grad_sampling_ptr;
773
+ const int grad_weight_stride = 1;
774
+ const int grad_loc_stride = 2;
775
+ const int qid_stride = num_heads * channels;
776
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
777
+
778
+ for (int l_col=0; l_col < num_levels; ++l_col)
779
+ {
780
+ const int level_start_id = data_level_start_index[l_col];
781
+ const int spatial_h_ptr = l_col << 1;
782
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
783
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
784
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
785
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
786
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
787
+
788
+ for (int p_col=0; p_col < num_point; ++p_col)
789
+ {
790
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
791
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
792
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
793
+
794
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
795
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
796
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
797
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
798
+ *(cache_grad_attn_weight+threadIdx.x)=0;
799
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
800
+ {
801
+ ms_deform_attn_col2im_bilinear(
802
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
803
+ top_grad, weight, grad_value_ptr,
804
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
805
+ }
806
+
807
+ __syncthreads();
808
+
809
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
810
+ {
811
+ if (tid < s) {
812
+ const unsigned int xid1 = tid << 1;
813
+ const unsigned int xid2 = (tid + s) << 1;
814
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
815
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
816
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
817
+ if (tid + (s << 1) < spre)
818
+ {
819
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
820
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
821
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
822
+ }
823
+ }
824
+ __syncthreads();
825
+ }
826
+
827
+ if (tid == 0)
828
+ {
829
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
830
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
831
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
832
+ }
833
+ __syncthreads();
834
+
835
+ data_weight_ptr += 1;
836
+ data_loc_w_ptr += 2;
837
+ grad_attn_weight += grad_weight_stride;
838
+ grad_sampling_loc += grad_loc_stride;
839
+ }
840
+ }
841
+ }
842
+ }
843
+
844
+
845
+ template <typename scalar_t>
846
+ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
847
+ const scalar_t *grad_col,
848
+ const scalar_t *data_value,
849
+ const int64_t *data_spatial_shapes,
850
+ const int64_t *data_level_start_index,
851
+ const scalar_t *data_sampling_loc,
852
+ const scalar_t *data_attn_weight,
853
+ const int batch_size,
854
+ const int spatial_size,
855
+ const int num_heads,
856
+ const int channels,
857
+ const int num_levels,
858
+ const int num_query,
859
+ const int num_point,
860
+ scalar_t *grad_value,
861
+ scalar_t *grad_sampling_loc,
862
+ scalar_t *grad_attn_weight)
863
+ {
864
+ CUDA_KERNEL_LOOP(index, n)
865
+ {
866
+ int _temp = index;
867
+ const int c_col = _temp % channels;
868
+ _temp /= channels;
869
+ const int sampling_index = _temp;
870
+ const int m_col = _temp % num_heads;
871
+ _temp /= num_heads;
872
+ const int q_col = _temp % num_query;
873
+ _temp /= num_query;
874
+ const int b_col = _temp;
875
+
876
+ const scalar_t top_grad = grad_col[index];
877
+
878
+ int data_weight_ptr = sampling_index * num_levels * num_point;
879
+ int data_loc_w_ptr = data_weight_ptr << 1;
880
+ const int grad_sampling_ptr = data_weight_ptr;
881
+ grad_sampling_loc += grad_sampling_ptr << 1;
882
+ grad_attn_weight += grad_sampling_ptr;
883
+ const int grad_weight_stride = 1;
884
+ const int grad_loc_stride = 2;
885
+ const int qid_stride = num_heads * channels;
886
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
887
+
888
+ for (int l_col=0; l_col < num_levels; ++l_col)
889
+ {
890
+ const int level_start_id = data_level_start_index[l_col];
891
+ const int spatial_h_ptr = l_col << 1;
892
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
893
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
894
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
895
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
896
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
897
+
898
+ for (int p_col=0; p_col < num_point; ++p_col)
899
+ {
900
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
901
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
902
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
903
+
904
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
905
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
906
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
907
+ {
908
+ ms_deform_attn_col2im_bilinear_gm(
909
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
910
+ top_grad, weight, grad_value_ptr,
911
+ grad_sampling_loc, grad_attn_weight);
912
+ }
913
+ data_weight_ptr += 1;
914
+ data_loc_w_ptr += 2;
915
+ grad_attn_weight += grad_weight_stride;
916
+ grad_sampling_loc += grad_loc_stride;
917
+ }
918
+ }
919
+ }
920
+ }
921
+
922
+
923
+ template <typename scalar_t>
924
+ void ms_deformable_im2col_cuda(cudaStream_t stream,
925
+ const scalar_t* data_value,
926
+ const int64_t* data_spatial_shapes,
927
+ const int64_t* data_level_start_index,
928
+ const scalar_t* data_sampling_loc,
929
+ const scalar_t* data_attn_weight,
930
+ const int batch_size,
931
+ const int spatial_size,
932
+ const int num_heads,
933
+ const int channels,
934
+ const int num_levels,
935
+ const int num_query,
936
+ const int num_point,
937
+ scalar_t* data_col)
938
+ {
939
+ const int num_kernels = batch_size * num_query * num_heads * channels;
940
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
941
+ const int num_threads = CUDA_NUM_THREADS;
942
+ ms_deformable_im2col_gpu_kernel<scalar_t>
943
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
944
+ 0, stream>>>(
945
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
946
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
947
+
948
+ cudaError_t err = cudaGetLastError();
949
+ if (err != cudaSuccess)
950
+ {
951
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
952
+ }
953
+
954
+ }
955
+
956
+ template <typename scalar_t>
957
+ void ms_deformable_col2im_cuda(cudaStream_t stream,
958
+ const scalar_t* grad_col,
959
+ const scalar_t* data_value,
960
+ const int64_t * data_spatial_shapes,
961
+ const int64_t * data_level_start_index,
962
+ const scalar_t * data_sampling_loc,
963
+ const scalar_t * data_attn_weight,
964
+ const int batch_size,
965
+ const int spatial_size,
966
+ const int num_heads,
967
+ const int channels,
968
+ const int num_levels,
969
+ const int num_query,
970
+ const int num_point,
971
+ scalar_t* grad_value,
972
+ scalar_t* grad_sampling_loc,
973
+ scalar_t* grad_attn_weight)
974
+ {
975
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
976
+ const int num_kernels = batch_size * num_query * num_heads * channels;
977
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
978
+ if (channels > 1024)
979
+ {
980
+ if ((channels & 1023) == 0)
981
+ {
982
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
983
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
984
+ num_threads*3*sizeof(scalar_t), stream>>>(
985
+ num_kernels,
986
+ grad_col,
987
+ data_value,
988
+ data_spatial_shapes,
989
+ data_level_start_index,
990
+ data_sampling_loc,
991
+ data_attn_weight,
992
+ batch_size,
993
+ spatial_size,
994
+ num_heads,
995
+ channels,
996
+ num_levels,
997
+ num_query,
998
+ num_point,
999
+ grad_value,
1000
+ grad_sampling_loc,
1001
+ grad_attn_weight);
1002
+ }
1003
+ else
1004
+ {
1005
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
1006
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1007
+ 0, stream>>>(
1008
+ num_kernels,
1009
+ grad_col,
1010
+ data_value,
1011
+ data_spatial_shapes,
1012
+ data_level_start_index,
1013
+ data_sampling_loc,
1014
+ data_attn_weight,
1015
+ batch_size,
1016
+ spatial_size,
1017
+ num_heads,
1018
+ channels,
1019
+ num_levels,
1020
+ num_query,
1021
+ num_point,
1022
+ grad_value,
1023
+ grad_sampling_loc,
1024
+ grad_attn_weight);
1025
+ }
1026
+ }
1027
+ else{
1028
+ switch(channels)
1029
+ {
1030
+ case 1:
1031
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
1032
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1033
+ 0, stream>>>(
1034
+ num_kernels,
1035
+ grad_col,
1036
+ data_value,
1037
+ data_spatial_shapes,
1038
+ data_level_start_index,
1039
+ data_sampling_loc,
1040
+ data_attn_weight,
1041
+ batch_size,
1042
+ spatial_size,
1043
+ num_heads,
1044
+ channels,
1045
+ num_levels,
1046
+ num_query,
1047
+ num_point,
1048
+ grad_value,
1049
+ grad_sampling_loc,
1050
+ grad_attn_weight);
1051
+ break;
1052
+ case 2:
1053
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
1054
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1055
+ 0, stream>>>(
1056
+ num_kernels,
1057
+ grad_col,
1058
+ data_value,
1059
+ data_spatial_shapes,
1060
+ data_level_start_index,
1061
+ data_sampling_loc,
1062
+ data_attn_weight,
1063
+ batch_size,
1064
+ spatial_size,
1065
+ num_heads,
1066
+ channels,
1067
+ num_levels,
1068
+ num_query,
1069
+ num_point,
1070
+ grad_value,
1071
+ grad_sampling_loc,
1072
+ grad_attn_weight);
1073
+ break;
1074
+ case 4:
1075
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
1076
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1077
+ 0, stream>>>(
1078
+ num_kernels,
1079
+ grad_col,
1080
+ data_value,
1081
+ data_spatial_shapes,
1082
+ data_level_start_index,
1083
+ data_sampling_loc,
1084
+ data_attn_weight,
1085
+ batch_size,
1086
+ spatial_size,
1087
+ num_heads,
1088
+ channels,
1089
+ num_levels,
1090
+ num_query,
1091
+ num_point,
1092
+ grad_value,
1093
+ grad_sampling_loc,
1094
+ grad_attn_weight);
1095
+ break;
1096
+ case 8:
1097
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
1098
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1099
+ 0, stream>>>(
1100
+ num_kernels,
1101
+ grad_col,
1102
+ data_value,
1103
+ data_spatial_shapes,
1104
+ data_level_start_index,
1105
+ data_sampling_loc,
1106
+ data_attn_weight,
1107
+ batch_size,
1108
+ spatial_size,
1109
+ num_heads,
1110
+ channels,
1111
+ num_levels,
1112
+ num_query,
1113
+ num_point,
1114
+ grad_value,
1115
+ grad_sampling_loc,
1116
+ grad_attn_weight);
1117
+ break;
1118
+ case 16:
1119
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
1120
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1121
+ 0, stream>>>(
1122
+ num_kernels,
1123
+ grad_col,
1124
+ data_value,
1125
+ data_spatial_shapes,
1126
+ data_level_start_index,
1127
+ data_sampling_loc,
1128
+ data_attn_weight,
1129
+ batch_size,
1130
+ spatial_size,
1131
+ num_heads,
1132
+ channels,
1133
+ num_levels,
1134
+ num_query,
1135
+ num_point,
1136
+ grad_value,
1137
+ grad_sampling_loc,
1138
+ grad_attn_weight);
1139
+ break;
1140
+ case 32:
1141
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
1142
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1143
+ 0, stream>>>(
1144
+ num_kernels,
1145
+ grad_col,
1146
+ data_value,
1147
+ data_spatial_shapes,
1148
+ data_level_start_index,
1149
+ data_sampling_loc,
1150
+ data_attn_weight,
1151
+ batch_size,
1152
+ spatial_size,
1153
+ num_heads,
1154
+ channels,
1155
+ num_levels,
1156
+ num_query,
1157
+ num_point,
1158
+ grad_value,
1159
+ grad_sampling_loc,
1160
+ grad_attn_weight);
1161
+ break;
1162
+ case 64:
1163
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
1164
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1165
+ 0, stream>>>(
1166
+ num_kernels,
1167
+ grad_col,
1168
+ data_value,
1169
+ data_spatial_shapes,
1170
+ data_level_start_index,
1171
+ data_sampling_loc,
1172
+ data_attn_weight,
1173
+ batch_size,
1174
+ spatial_size,
1175
+ num_heads,
1176
+ channels,
1177
+ num_levels,
1178
+ num_query,
1179
+ num_point,
1180
+ grad_value,
1181
+ grad_sampling_loc,
1182
+ grad_attn_weight);
1183
+ break;
1184
+ case 128:
1185
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
1186
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1187
+ 0, stream>>>(
1188
+ num_kernels,
1189
+ grad_col,
1190
+ data_value,
1191
+ data_spatial_shapes,
1192
+ data_level_start_index,
1193
+ data_sampling_loc,
1194
+ data_attn_weight,
1195
+ batch_size,
1196
+ spatial_size,
1197
+ num_heads,
1198
+ channels,
1199
+ num_levels,
1200
+ num_query,
1201
+ num_point,
1202
+ grad_value,
1203
+ grad_sampling_loc,
1204
+ grad_attn_weight);
1205
+ break;
1206
+ case 256:
1207
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
1208
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1209
+ 0, stream>>>(
1210
+ num_kernels,
1211
+ grad_col,
1212
+ data_value,
1213
+ data_spatial_shapes,
1214
+ data_level_start_index,
1215
+ data_sampling_loc,
1216
+ data_attn_weight,
1217
+ batch_size,
1218
+ spatial_size,
1219
+ num_heads,
1220
+ channels,
1221
+ num_levels,
1222
+ num_query,
1223
+ num_point,
1224
+ grad_value,
1225
+ grad_sampling_loc,
1226
+ grad_attn_weight);
1227
+ break;
1228
+ case 512:
1229
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
1230
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1231
+ 0, stream>>>(
1232
+ num_kernels,
1233
+ grad_col,
1234
+ data_value,
1235
+ data_spatial_shapes,
1236
+ data_level_start_index,
1237
+ data_sampling_loc,
1238
+ data_attn_weight,
1239
+ batch_size,
1240
+ spatial_size,
1241
+ num_heads,
1242
+ channels,
1243
+ num_levels,
1244
+ num_query,
1245
+ num_point,
1246
+ grad_value,
1247
+ grad_sampling_loc,
1248
+ grad_attn_weight);
1249
+ break;
1250
+ case 1024:
1251
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
1252
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1253
+ 0, stream>>>(
1254
+ num_kernels,
1255
+ grad_col,
1256
+ data_value,
1257
+ data_spatial_shapes,
1258
+ data_level_start_index,
1259
+ data_sampling_loc,
1260
+ data_attn_weight,
1261
+ batch_size,
1262
+ spatial_size,
1263
+ num_heads,
1264
+ channels,
1265
+ num_levels,
1266
+ num_query,
1267
+ num_point,
1268
+ grad_value,
1269
+ grad_sampling_loc,
1270
+ grad_attn_weight);
1271
+ break;
1272
+ default:
1273
+ if (channels < 64)
1274
+ {
1275
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
1276
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1277
+ num_threads*3*sizeof(scalar_t), stream>>>(
1278
+ num_kernels,
1279
+ grad_col,
1280
+ data_value,
1281
+ data_spatial_shapes,
1282
+ data_level_start_index,
1283
+ data_sampling_loc,
1284
+ data_attn_weight,
1285
+ batch_size,
1286
+ spatial_size,
1287
+ num_heads,
1288
+ channels,
1289
+ num_levels,
1290
+ num_query,
1291
+ num_point,
1292
+ grad_value,
1293
+ grad_sampling_loc,
1294
+ grad_attn_weight);
1295
+ }
1296
+ else
1297
+ {
1298
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
1299
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1300
+ num_threads*3*sizeof(scalar_t), stream>>>(
1301
+ num_kernels,
1302
+ grad_col,
1303
+ data_value,
1304
+ data_spatial_shapes,
1305
+ data_level_start_index,
1306
+ data_sampling_loc,
1307
+ data_attn_weight,
1308
+ batch_size,
1309
+ spatial_size,
1310
+ num_heads,
1311
+ channels,
1312
+ num_levels,
1313
+ num_query,
1314
+ num_point,
1315
+ grad_value,
1316
+ grad_sampling_loc,
1317
+ grad_attn_weight);
1318
+ }
1319
+ }
1320
+ }
1321
+ cudaError_t err = cudaGetLastError();
1322
+ if (err != cudaSuccess)
1323
+ {
1324
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
1325
+ }
1326
+
1327
+ }
GroundingDINO/groundingdino/models/GroundingDINO/csrc/cuda_version.cu ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #include <cuda_runtime_api.h>
2
+
3
+ namespace groundingdino {
4
+ int get_cudart_version() {
5
+ return CUDART_VERSION;
6
+ }
7
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/csrc/vision.cpp ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ #include "MsDeformAttn/ms_deform_attn.h"
4
+
5
+ namespace groundingdino {
6
+
7
+ #ifdef WITH_CUDA
8
+ extern int get_cudart_version();
9
+ #endif
10
+
11
+ std::string get_cuda_version() {
12
+ #ifdef WITH_CUDA
13
+ std::ostringstream oss;
14
+
15
+ // copied from
16
+ // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
17
+ auto printCudaStyleVersion = [&](int v) {
18
+ oss << (v / 1000) << "." << (v / 10 % 100);
19
+ if (v % 10 != 0) {
20
+ oss << "." << (v % 10);
21
+ }
22
+ };
23
+ printCudaStyleVersion(get_cudart_version());
24
+ return oss.str();
25
+ #else
26
+ return std::string("not available");
27
+ #endif
28
+ }
29
+
30
+ // similar to
31
+ // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp
32
+ std::string get_compiler_version() {
33
+ std::ostringstream ss;
34
+ #if defined(__GNUC__)
35
+ #ifndef __clang__
36
+ { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; }
37
+ #endif
38
+ #endif
39
+
40
+ #if defined(__clang_major__)
41
+ {
42
+ ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
43
+ << __clang_patchlevel__;
44
+ }
45
+ #endif
46
+
47
+ #if defined(_MSC_VER)
48
+ { ss << "MSVC " << _MSC_FULL_VER; }
49
+ #endif
50
+ return ss.str();
51
+ }
52
+
53
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
54
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
55
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
56
+ }
57
+
58
+ } // namespace groundingdino
GroundingDINO/groundingdino/models/GroundingDINO/fuse_modules.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from timm.models.layers import DropPath
12
+
13
+
14
+ class FeatureResizer(nn.Module):
15
+ """
16
+ This class takes as input a set of embeddings of dimension C1 and outputs a set of
17
+ embedding of dimension C2, after a linear transformation, dropout and normalization (LN).
18
+ """
19
+
20
+ def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
21
+ super().__init__()
22
+ self.do_ln = do_ln
23
+ # Object feature encoding
24
+ self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
25
+ self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
26
+ self.dropout = nn.Dropout(dropout)
27
+
28
+ def forward(self, encoder_features):
29
+ x = self.fc(encoder_features)
30
+ if self.do_ln:
31
+ x = self.layer_norm(x)
32
+ output = self.dropout(x)
33
+ return output
34
+
35
+
36
+ def l1norm(X, dim, eps=1e-8):
37
+ """L1-normalize columns of X"""
38
+ norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
39
+ X = torch.div(X, norm)
40
+ return X
41
+
42
+
43
+ def l2norm(X, dim, eps=1e-8):
44
+ """L2-normalize columns of X"""
45
+ norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
46
+ X = torch.div(X, norm)
47
+ return X
48
+
49
+
50
+ def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
51
+ """
52
+ query: (n_context, queryL, d)
53
+ context: (n_context, sourceL, d)
54
+ """
55
+ batch_size_q, queryL = query.size(0), query.size(1)
56
+ batch_size, sourceL = context.size(0), context.size(1)
57
+
58
+ # Get attention
59
+ # --> (batch, d, queryL)
60
+ queryT = torch.transpose(query, 1, 2)
61
+
62
+ # (batch, sourceL, d)(batch, d, queryL)
63
+ # --> (batch, sourceL, queryL)
64
+ attn = torch.bmm(context, queryT)
65
+ if raw_feature_norm == "softmax":
66
+ # --> (batch*sourceL, queryL)
67
+ attn = attn.view(batch_size * sourceL, queryL)
68
+ attn = nn.Softmax()(attn)
69
+ # --> (batch, sourceL, queryL)
70
+ attn = attn.view(batch_size, sourceL, queryL)
71
+ elif raw_feature_norm == "l2norm":
72
+ attn = l2norm(attn, 2)
73
+ elif raw_feature_norm == "clipped_l2norm":
74
+ attn = nn.LeakyReLU(0.1)(attn)
75
+ attn = l2norm(attn, 2)
76
+ else:
77
+ raise ValueError("unknown first norm type:", raw_feature_norm)
78
+ # --> (batch, queryL, sourceL)
79
+ attn = torch.transpose(attn, 1, 2).contiguous()
80
+ # --> (batch*queryL, sourceL)
81
+ attn = attn.view(batch_size * queryL, sourceL)
82
+ attn = nn.Softmax()(attn * smooth)
83
+ # --> (batch, queryL, sourceL)
84
+ attn = attn.view(batch_size, queryL, sourceL)
85
+ # --> (batch, sourceL, queryL)
86
+ attnT = torch.transpose(attn, 1, 2).contiguous()
87
+
88
+ # --> (batch, d, sourceL)
89
+ contextT = torch.transpose(context, 1, 2)
90
+ # (batch x d x sourceL)(batch x sourceL x queryL)
91
+ # --> (batch, d, queryL)
92
+ weightedContext = torch.bmm(contextT, attnT)
93
+ # --> (batch, queryL, d)
94
+ weightedContext = torch.transpose(weightedContext, 1, 2)
95
+
96
+ return weightedContext, attnT
97
+
98
+
99
+ class BiMultiHeadAttention(nn.Module):
100
+ def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
101
+ super(BiMultiHeadAttention, self).__init__()
102
+
103
+ self.embed_dim = embed_dim
104
+ self.num_heads = num_heads
105
+ self.head_dim = embed_dim // num_heads
106
+ self.v_dim = v_dim
107
+ self.l_dim = l_dim
108
+
109
+ assert (
110
+ self.head_dim * self.num_heads == self.embed_dim
111
+ ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
112
+ self.scale = self.head_dim ** (-0.5)
113
+ self.dropout = dropout
114
+
115
+ self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
116
+ self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
117
+ self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
118
+ self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
119
+
120
+ self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
121
+ self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
122
+
123
+ self.stable_softmax_2d = True
124
+ self.clamp_min_for_underflow = True
125
+ self.clamp_max_for_overflow = True
126
+
127
+ self._reset_parameters()
128
+
129
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
130
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
131
+
132
+ def _reset_parameters(self):
133
+ nn.init.xavier_uniform_(self.v_proj.weight)
134
+ self.v_proj.bias.data.fill_(0)
135
+ nn.init.xavier_uniform_(self.l_proj.weight)
136
+ self.l_proj.bias.data.fill_(0)
137
+ nn.init.xavier_uniform_(self.values_v_proj.weight)
138
+ self.values_v_proj.bias.data.fill_(0)
139
+ nn.init.xavier_uniform_(self.values_l_proj.weight)
140
+ self.values_l_proj.bias.data.fill_(0)
141
+ nn.init.xavier_uniform_(self.out_v_proj.weight)
142
+ self.out_v_proj.bias.data.fill_(0)
143
+ nn.init.xavier_uniform_(self.out_l_proj.weight)
144
+ self.out_l_proj.bias.data.fill_(0)
145
+
146
+ def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
147
+ """_summary_
148
+
149
+ Args:
150
+ v (_type_): bs, n_img, dim
151
+ l (_type_): bs, n_text, dim
152
+ attention_mask_v (_type_, optional): _description_. bs, n_img
153
+ attention_mask_l (_type_, optional): _description_. bs, n_text
154
+
155
+ Returns:
156
+ _type_: _description_
157
+ """
158
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
159
+ # import ipdb; ipdb.set_trace()
160
+ bsz, tgt_len, _ = v.size()
161
+
162
+ query_states = self.v_proj(v) * self.scale
163
+ key_states = self._shape(self.l_proj(l), -1, bsz)
164
+ value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
165
+ value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
166
+
167
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
168
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
169
+ key_states = key_states.view(*proj_shape)
170
+ value_v_states = value_v_states.view(*proj_shape)
171
+ value_l_states = value_l_states.view(*proj_shape)
172
+
173
+ src_len = key_states.size(1)
174
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt
175
+
176
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
177
+ raise ValueError(
178
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
179
+ )
180
+
181
+ if self.stable_softmax_2d:
182
+ attn_weights = attn_weights - attn_weights.max()
183
+
184
+ if self.clamp_min_for_underflow:
185
+ attn_weights = torch.clamp(
186
+ attn_weights, min=-50000
187
+ ) # Do not increase -50000, data type half has quite limited range
188
+ if self.clamp_max_for_overflow:
189
+ attn_weights = torch.clamp(
190
+ attn_weights, max=50000
191
+ ) # Do not increase 50000, data type half has quite limited range
192
+
193
+ attn_weights_T = attn_weights.transpose(1, 2)
194
+ attn_weights_l = attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]
195
+ if self.clamp_min_for_underflow:
196
+ attn_weights_l = torch.clamp(
197
+ attn_weights_l, min=-50000
198
+ ) # Do not increase -50000, data type half has quite limited range
199
+ if self.clamp_max_for_overflow:
200
+ attn_weights_l = torch.clamp(
201
+ attn_weights_l, max=50000
202
+ ) # Do not increase 50000, data type half has quite limited range
203
+
204
+ # mask vison for language
205
+ if attention_mask_v is not None:
206
+ attention_mask_v = (
207
+ attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
208
+ )
209
+ attn_weights_l.masked_fill_(attention_mask_v, float("-inf"))
210
+
211
+ attn_weights_l = attn_weights_l.softmax(dim=-1)
212
+
213
+ # mask language for vision
214
+ if attention_mask_l is not None:
215
+ attention_mask_l = (
216
+ attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
217
+ )
218
+ attn_weights.masked_fill_(attention_mask_l, float("-inf"))
219
+ attn_weights_v = attn_weights.softmax(dim=-1)
220
+
221
+ attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
222
+ attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
223
+
224
+ attn_output_v = torch.bmm(attn_probs_v, value_l_states)
225
+ attn_output_l = torch.bmm(attn_probs_l, value_v_states)
226
+
227
+ if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
228
+ raise ValueError(
229
+ f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
230
+ )
231
+
232
+ if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
233
+ raise ValueError(
234
+ f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
235
+ )
236
+
237
+ attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
238
+ attn_output_v = attn_output_v.transpose(1, 2)
239
+ attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
240
+
241
+ attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
242
+ attn_output_l = attn_output_l.transpose(1, 2)
243
+ attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
244
+
245
+ attn_output_v = self.out_v_proj(attn_output_v)
246
+ attn_output_l = self.out_l_proj(attn_output_l)
247
+
248
+ return attn_output_v, attn_output_l
249
+
250
+
251
+ # Bi-Direction MHA (text->image, image->text)
252
+ class BiAttentionBlock(nn.Module):
253
+ def __init__(
254
+ self,
255
+ v_dim,
256
+ l_dim,
257
+ embed_dim,
258
+ num_heads,
259
+ dropout=0.1,
260
+ drop_path=0.0,
261
+ init_values=1e-4,
262
+ cfg=None,
263
+ ):
264
+ """
265
+ Inputs:
266
+ embed_dim - Dimensionality of input and attention feature vectors
267
+ hidden_dim - Dimensionality of hidden layer in feed-forward network
268
+ (usually 2-4x larger than embed_dim)
269
+ num_heads - Number of heads to use in the Multi-Head Attention block
270
+ dropout - Amount of dropout to apply in the feed-forward network
271
+ """
272
+ super(BiAttentionBlock, self).__init__()
273
+
274
+ # pre layer norm
275
+ self.layer_norm_v = nn.LayerNorm(v_dim)
276
+ self.layer_norm_l = nn.LayerNorm(l_dim)
277
+ self.attn = BiMultiHeadAttention(
278
+ v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout
279
+ )
280
+
281
+ # add layer scale for training stability
282
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
283
+ self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=True)
284
+ self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
285
+
286
+ def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
287
+ v = self.layer_norm_v(v)
288
+ l = self.layer_norm_l(l)
289
+ delta_v, delta_l = self.attn(
290
+ v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l
291
+ )
292
+ # v, l = v + delta_v, l + delta_l
293
+ v = v + self.drop_path(self.gamma_v * delta_v)
294
+ l = l + self.drop_path(self.gamma_l * delta_l)
295
+ return v, l
296
+
297
+ # def forward(self, v:List[torch.Tensor], l, attention_mask_v=None, attention_mask_l=None)
GroundingDINO/groundingdino/models/GroundingDINO/groundingdino.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Conditional DETR model and criterion classes.
8
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Modified from DETR (https://github.com/facebookresearch/detr)
12
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13
+ # ------------------------------------------------------------------------
14
+ # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
15
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
16
+ # ------------------------------------------------------------------------
17
+ import copy
18
+ from typing import List
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from torch import nn
23
+ from torchvision.ops.boxes import nms
24
+ from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
25
+
26
+ from groundingdino.util import box_ops, get_tokenlizer
27
+ from groundingdino.util.misc import (
28
+ NestedTensor,
29
+ accuracy,
30
+ get_world_size,
31
+ interpolate,
32
+ inverse_sigmoid,
33
+ is_dist_avail_and_initialized,
34
+ nested_tensor_from_tensor_list,
35
+ )
36
+ from groundingdino.util.utils import get_phrases_from_posmap
37
+ from groundingdino.util.visualizer import COCOVisualizer
38
+ from groundingdino.util.vl_utils import create_positive_map_from_span
39
+
40
+ from ..registry import MODULE_BUILD_FUNCS
41
+ from .backbone import build_backbone
42
+ from .bertwarper import (
43
+ BertModelWarper,
44
+ generate_masks_with_special_tokens,
45
+ generate_masks_with_special_tokens_and_transfer_map,
46
+ )
47
+ from .transformer import build_transformer
48
+ from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss
49
+
50
+
51
+ class GroundingDINO(nn.Module):
52
+ """This is the Cross-Attention Detector module that performs object detection"""
53
+
54
+ def __init__(
55
+ self,
56
+ backbone,
57
+ transformer,
58
+ num_queries,
59
+ aux_loss=False,
60
+ iter_update=False,
61
+ query_dim=2,
62
+ num_feature_levels=1,
63
+ nheads=8,
64
+ # two stage
65
+ two_stage_type="no", # ['no', 'standard']
66
+ dec_pred_bbox_embed_share=True,
67
+ two_stage_class_embed_share=True,
68
+ two_stage_bbox_embed_share=True,
69
+ num_patterns=0,
70
+ dn_number=100,
71
+ dn_box_noise_scale=0.4,
72
+ dn_label_noise_ratio=0.5,
73
+ dn_labelbook_size=100,
74
+ text_encoder_type="bert-base-uncased",
75
+ sub_sentence_present=True,
76
+ max_text_len=256,
77
+ ):
78
+ """Initializes the model.
79
+ Parameters:
80
+ backbone: torch module of the backbone to be used. See backbone.py
81
+ transformer: torch module of the transformer architecture. See transformer.py
82
+ num_queries: number of object queries, ie detection slot. This is the maximal number of objects
83
+ Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
84
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
85
+ """
86
+ super().__init__()
87
+ self.num_queries = num_queries
88
+ self.transformer = transformer
89
+ self.hidden_dim = hidden_dim = transformer.d_model
90
+ self.num_feature_levels = num_feature_levels
91
+ self.nheads = nheads
92
+ self.max_text_len = 256
93
+ self.sub_sentence_present = sub_sentence_present
94
+
95
+ # setting query dim
96
+ self.query_dim = query_dim
97
+ assert query_dim == 4
98
+
99
+ # for dn training
100
+ self.num_patterns = num_patterns
101
+ self.dn_number = dn_number
102
+ self.dn_box_noise_scale = dn_box_noise_scale
103
+ self.dn_label_noise_ratio = dn_label_noise_ratio
104
+ self.dn_labelbook_size = dn_labelbook_size
105
+
106
+ # bert
107
+ self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
108
+ self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
109
+ self.bert.pooler.dense.weight.requires_grad_(False)
110
+ self.bert.pooler.dense.bias.requires_grad_(False)
111
+ self.bert = BertModelWarper(bert_model=self.bert)
112
+
113
+ self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
114
+ nn.init.constant_(self.feat_map.bias.data, 0)
115
+ nn.init.xavier_uniform_(self.feat_map.weight.data)
116
+ # freeze
117
+
118
+ # special tokens
119
+ self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])
120
+
121
+ # prepare input projection layers
122
+ if num_feature_levels > 1:
123
+ num_backbone_outs = len(backbone.num_channels)
124
+ input_proj_list = []
125
+ for _ in range(num_backbone_outs):
126
+ in_channels = backbone.num_channels[_]
127
+ input_proj_list.append(
128
+ nn.Sequential(
129
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
130
+ nn.GroupNorm(32, hidden_dim),
131
+ )
132
+ )
133
+ for _ in range(num_feature_levels - num_backbone_outs):
134
+ input_proj_list.append(
135
+ nn.Sequential(
136
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
137
+ nn.GroupNorm(32, hidden_dim),
138
+ )
139
+ )
140
+ in_channels = hidden_dim
141
+ self.input_proj = nn.ModuleList(input_proj_list)
142
+ else:
143
+ assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!"
144
+ self.input_proj = nn.ModuleList(
145
+ [
146
+ nn.Sequential(
147
+ nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
148
+ nn.GroupNorm(32, hidden_dim),
149
+ )
150
+ ]
151
+ )
152
+
153
+ self.backbone = backbone
154
+ self.aux_loss = aux_loss
155
+ self.box_pred_damping = box_pred_damping = None
156
+
157
+ self.iter_update = iter_update
158
+ assert iter_update, "Why not iter_update?"
159
+
160
+ # prepare pred layers
161
+ self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
162
+ # prepare class & box embed
163
+ _class_embed = ContrastiveEmbed()
164
+
165
+ _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
166
+ nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
167
+ nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
168
+
169
+ if dec_pred_bbox_embed_share:
170
+ box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
171
+ else:
172
+ box_embed_layerlist = [
173
+ copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)
174
+ ]
175
+ class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
176
+ self.bbox_embed = nn.ModuleList(box_embed_layerlist)
177
+ self.class_embed = nn.ModuleList(class_embed_layerlist)
178
+ self.transformer.decoder.bbox_embed = self.bbox_embed
179
+ self.transformer.decoder.class_embed = self.class_embed
180
+
181
+ # two stage
182
+ self.two_stage_type = two_stage_type
183
+ assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
184
+ two_stage_type
185
+ )
186
+ if two_stage_type != "no":
187
+ if two_stage_bbox_embed_share:
188
+ assert dec_pred_bbox_embed_share
189
+ self.transformer.enc_out_bbox_embed = _bbox_embed
190
+ else:
191
+ self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
192
+
193
+ if two_stage_class_embed_share:
194
+ assert dec_pred_bbox_embed_share
195
+ self.transformer.enc_out_class_embed = _class_embed
196
+ else:
197
+ self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
198
+
199
+ self.refpoint_embed = None
200
+
201
+ self._reset_parameters()
202
+
203
+ def _reset_parameters(self):
204
+ # init input_proj
205
+ for proj in self.input_proj:
206
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
207
+ nn.init.constant_(proj[0].bias, 0)
208
+
209
+ def init_ref_points(self, use_num_queries):
210
+ self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
211
+
212
+ def forward(self, samples: NestedTensor, targets: List = None, **kw):
213
+ """The forward expects a NestedTensor, which consists of:
214
+ - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
215
+ - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
216
+
217
+ It returns a dict with the following elements:
218
+ - "pred_logits": the classification logits (including no-object) for all queries.
219
+ Shape= [batch_size x num_queries x num_classes]
220
+ - "pred_boxes": The normalized boxes coordinates for all queries, represented as
221
+ (center_x, center_y, width, height). These values are normalized in [0, 1],
222
+ relative to the size of each individual image (disregarding possible padding).
223
+ See PostProcess for information on how to retrieve the unnormalized bounding box.
224
+ - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
225
+ dictionnaries containing the two above keys for each decoder layer.
226
+ """
227
+ if targets is None:
228
+ captions = kw["captions"]
229
+ else:
230
+ captions = [t["caption"] for t in targets]
231
+ len(captions)
232
+
233
+ # encoder texts
234
+ tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
235
+ samples.device
236
+ )
237
+ (
238
+ text_self_attention_masks,
239
+ position_ids,
240
+ cate_to_token_mask_list,
241
+ ) = generate_masks_with_special_tokens_and_transfer_map(
242
+ tokenized, self.specical_tokens, self.tokenizer
243
+ )
244
+
245
+ if text_self_attention_masks.shape[1] > self.max_text_len:
246
+ text_self_attention_masks = text_self_attention_masks[
247
+ :, : self.max_text_len, : self.max_text_len
248
+ ]
249
+ position_ids = position_ids[:, : self.max_text_len]
250
+ tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
251
+ tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
252
+ tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]
253
+
254
+ # extract text embeddings
255
+ if self.sub_sentence_present:
256
+ tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
257
+ tokenized_for_encoder["attention_mask"] = text_self_attention_masks
258
+ tokenized_for_encoder["position_ids"] = position_ids
259
+ else:
260
+ # import ipdb; ipdb.set_trace()
261
+ tokenized_for_encoder = tokenized
262
+
263
+ bert_output = self.bert(**tokenized_for_encoder) # bs, 195, 768
264
+
265
+ encoded_text = self.feat_map(bert_output["last_hidden_state"]) # bs, 195, d_model
266
+ text_token_mask = tokenized.attention_mask.bool() # bs, 195
267
+ # text_token_mask: True for nomask, False for mask
268
+ # text_self_attention_masks: True for nomask, False for mask
269
+
270
+ if encoded_text.shape[1] > self.max_text_len:
271
+ encoded_text = encoded_text[:, : self.max_text_len, :]
272
+ text_token_mask = text_token_mask[:, : self.max_text_len]
273
+ position_ids = position_ids[:, : self.max_text_len]
274
+ text_self_attention_masks = text_self_attention_masks[
275
+ :, : self.max_text_len, : self.max_text_len
276
+ ]
277
+
278
+ text_dict = {
279
+ "encoded_text": encoded_text, # bs, 195, d_model
280
+ "text_token_mask": text_token_mask, # bs, 195
281
+ "position_ids": position_ids, # bs, 195
282
+ "text_self_attention_masks": text_self_attention_masks, # bs, 195,195
283
+ }
284
+
285
+ # import ipdb; ipdb.set_trace()
286
+
287
+ if isinstance(samples, (list, torch.Tensor)):
288
+ samples = nested_tensor_from_tensor_list(samples)
289
+ features, poss = self.backbone(samples)
290
+
291
+ srcs = []
292
+ masks = []
293
+ for l, feat in enumerate(features):
294
+ src, mask = feat.decompose()
295
+ srcs.append(self.input_proj[l](src))
296
+ masks.append(mask)
297
+ assert mask is not None
298
+ if self.num_feature_levels > len(srcs):
299
+ _len_srcs = len(srcs)
300
+ for l in range(_len_srcs, self.num_feature_levels):
301
+ if l == _len_srcs:
302
+ src = self.input_proj[l](features[-1].tensors)
303
+ else:
304
+ src = self.input_proj[l](srcs[-1])
305
+ m = samples.mask
306
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
307
+ pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
308
+ srcs.append(src)
309
+ masks.append(mask)
310
+ poss.append(pos_l)
311
+
312
+ input_query_bbox = input_query_label = attn_mask = dn_meta = None
313
+ hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
314
+ srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict
315
+ )
316
+
317
+ # deformable-detr-like anchor update
318
+ outputs_coord_list = []
319
+ for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
320
+ zip(reference[:-1], self.bbox_embed, hs)
321
+ ):
322
+ layer_delta_unsig = layer_bbox_embed(layer_hs)
323
+ layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
324
+ layer_outputs_unsig = layer_outputs_unsig.sigmoid()
325
+ outputs_coord_list.append(layer_outputs_unsig)
326
+ outputs_coord_list = torch.stack(outputs_coord_list)
327
+
328
+ # output
329
+ outputs_class = torch.stack(
330
+ [
331
+ layer_cls_embed(layer_hs, text_dict)
332
+ for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
333
+ ]
334
+ )
335
+ out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}
336
+
337
+ # # for intermediate outputs
338
+ # if self.aux_loss:
339
+ # out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
340
+
341
+ # # for encoder output
342
+ # if hs_enc is not None:
343
+ # # prepare intermediate outputs
344
+ # interm_coord = ref_enc[-1]
345
+ # interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict)
346
+ # out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
347
+ # out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
348
+
349
+ return out
350
+
351
+ @torch.jit.unused
352
+ def _set_aux_loss(self, outputs_class, outputs_coord):
353
+ # this is a workaround to make torchscript happy, as torchscript
354
+ # doesn't support dictionary with non-homogeneous values, such
355
+ # as a dict having both a Tensor and a list.
356
+ return [
357
+ {"pred_logits": a, "pred_boxes": b}
358
+ for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
359
+ ]
360
+
361
+
362
+ @MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
363
+ def build_groundingdino(args):
364
+
365
+ backbone = build_backbone(args)
366
+ transformer = build_transformer(args)
367
+
368
+ dn_labelbook_size = args.dn_labelbook_size
369
+ dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
370
+ sub_sentence_present = args.sub_sentence_present
371
+
372
+ model = GroundingDINO(
373
+ backbone,
374
+ transformer,
375
+ num_queries=args.num_queries,
376
+ aux_loss=True,
377
+ iter_update=True,
378
+ query_dim=4,
379
+ num_feature_levels=args.num_feature_levels,
380
+ nheads=args.nheads,
381
+ dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
382
+ two_stage_type=args.two_stage_type,
383
+ two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
384
+ two_stage_class_embed_share=args.two_stage_class_embed_share,
385
+ num_patterns=args.num_patterns,
386
+ dn_number=0,
387
+ dn_box_noise_scale=args.dn_box_noise_scale,
388
+ dn_label_noise_ratio=args.dn_label_noise_ratio,
389
+ dn_labelbook_size=dn_labelbook_size,
390
+ text_encoder_type=args.text_encoder_type,
391
+ sub_sentence_present=sub_sentence_present,
392
+ max_text_len=args.max_text_len,
393
+ )
394
+
395
+ return model
GroundingDINO/groundingdino/models/GroundingDINO/ms_deform_attn.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Deformable DETR
8
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------------------------------
11
+ # Modified from:
12
+ # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py
13
+ # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
14
+ # https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/multi_scale_deform_attn.py
15
+ # ------------------------------------------------------------------------------------------------
16
+
17
+ import math
18
+ import warnings
19
+ from typing import Optional
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from torch.autograd import Function
25
+ from torch.autograd.function import once_differentiable
26
+ from torch.nn.init import constant_, xavier_uniform_
27
+
28
+ try:
29
+ from groundingdino import _C
30
+ except:
31
+ warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!")
32
+
33
+
34
+ # helpers
35
+ def _is_power_of_2(n):
36
+ if (not isinstance(n, int)) or (n < 0):
37
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
38
+ return (n & (n - 1) == 0) and n != 0
39
+
40
+
41
+ class MultiScaleDeformableAttnFunction(Function):
42
+ @staticmethod
43
+ def forward(
44
+ ctx,
45
+ value,
46
+ value_spatial_shapes,
47
+ value_level_start_index,
48
+ sampling_locations,
49
+ attention_weights,
50
+ im2col_step,
51
+ ):
52
+ ctx.im2col_step = im2col_step
53
+ output = _C.ms_deform_attn_forward(
54
+ value,
55
+ value_spatial_shapes,
56
+ value_level_start_index,
57
+ sampling_locations,
58
+ attention_weights,
59
+ ctx.im2col_step,
60
+ )
61
+ ctx.save_for_backward(
62
+ value,
63
+ value_spatial_shapes,
64
+ value_level_start_index,
65
+ sampling_locations,
66
+ attention_weights,
67
+ )
68
+ return output
69
+
70
+ @staticmethod
71
+ @once_differentiable
72
+ def backward(ctx, grad_output):
73
+ (
74
+ value,
75
+ value_spatial_shapes,
76
+ value_level_start_index,
77
+ sampling_locations,
78
+ attention_weights,
79
+ ) = ctx.saved_tensors
80
+ grad_value, grad_sampling_loc, grad_attn_weight = _C.ms_deform_attn_backward(
81
+ value,
82
+ value_spatial_shapes,
83
+ value_level_start_index,
84
+ sampling_locations,
85
+ attention_weights,
86
+ grad_output,
87
+ ctx.im2col_step,
88
+ )
89
+
90
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
91
+
92
+
93
+ def multi_scale_deformable_attn_pytorch(
94
+ value: torch.Tensor,
95
+ value_spatial_shapes: torch.Tensor,
96
+ sampling_locations: torch.Tensor,
97
+ attention_weights: torch.Tensor,
98
+ ) -> torch.Tensor:
99
+
100
+ bs, _, num_heads, embed_dims = value.shape
101
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
102
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
103
+ sampling_grids = 2 * sampling_locations - 1
104
+ sampling_value_list = []
105
+ for level, (H_, W_) in enumerate(value_spatial_shapes):
106
+ # bs, H_*W_, num_heads, embed_dims ->
107
+ # bs, H_*W_, num_heads*embed_dims ->
108
+ # bs, num_heads*embed_dims, H_*W_ ->
109
+ # bs*num_heads, embed_dims, H_, W_
110
+ value_l_ = (
111
+ value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
112
+ )
113
+ # bs, num_queries, num_heads, num_points, 2 ->
114
+ # bs, num_heads, num_queries, num_points, 2 ->
115
+ # bs*num_heads, num_queries, num_points, 2
116
+ sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
117
+ # bs*num_heads, embed_dims, num_queries, num_points
118
+ sampling_value_l_ = F.grid_sample(
119
+ value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
120
+ )
121
+ sampling_value_list.append(sampling_value_l_)
122
+ # (bs, num_queries, num_heads, num_levels, num_points) ->
123
+ # (bs, num_heads, num_queries, num_levels, num_points) ->
124
+ # (bs, num_heads, 1, num_queries, num_levels*num_points)
125
+ attention_weights = attention_weights.transpose(1, 2).reshape(
126
+ bs * num_heads, 1, num_queries, num_levels * num_points
127
+ )
128
+ output = (
129
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
130
+ .sum(-1)
131
+ .view(bs, num_heads * embed_dims, num_queries)
132
+ )
133
+ return output.transpose(1, 2).contiguous()
134
+
135
+
136
+ class MultiScaleDeformableAttention(nn.Module):
137
+ """Multi-Scale Deformable Attention Module used in Deformable-DETR
138
+
139
+ `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
140
+ <https://arxiv.org/pdf/2010.04159.pdf>`_.
141
+
142
+ Args:
143
+ embed_dim (int): The embedding dimension of Attention. Default: 256.
144
+ num_heads (int): The number of attention heads. Default: 8.
145
+ num_levels (int): The number of feature map used in Attention. Default: 4.
146
+ num_points (int): The number of sampling points for each query
147
+ in each head. Default: 4.
148
+ img2col_steps (int): The step used in image_to_column. Defualt: 64.
149
+ dropout (float): Dropout layer used in output. Default: 0.1.
150
+ batch_first (bool): if ``True``, then the input and output tensor will be
151
+ provided as `(bs, n, embed_dim)`. Default: False. `(n, bs, embed_dim)`
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ embed_dim: int = 256,
157
+ num_heads: int = 8,
158
+ num_levels: int = 4,
159
+ num_points: int = 4,
160
+ img2col_step: int = 64,
161
+ batch_first: bool = False,
162
+ ):
163
+ super().__init__()
164
+ if embed_dim % num_heads != 0:
165
+ raise ValueError(
166
+ "embed_dim must be divisible by num_heads, but got {} and {}".format(
167
+ embed_dim, num_heads
168
+ )
169
+ )
170
+ head_dim = embed_dim // num_heads
171
+
172
+ self.batch_first = batch_first
173
+
174
+ if not _is_power_of_2(head_dim):
175
+ warnings.warn(
176
+ """
177
+ You'd better set d_model in MSDeformAttn to make sure that
178
+ each dim of the attention head a power of 2, which is more efficient.
179
+ """
180
+ )
181
+
182
+ self.im2col_step = img2col_step
183
+ self.embed_dim = embed_dim
184
+ self.num_heads = num_heads
185
+ self.num_levels = num_levels
186
+ self.num_points = num_points
187
+ self.sampling_offsets = nn.Linear(embed_dim, num_heads * num_levels * num_points * 2)
188
+ self.attention_weights = nn.Linear(embed_dim, num_heads * num_levels * num_points)
189
+ self.value_proj = nn.Linear(embed_dim, embed_dim)
190
+ self.output_proj = nn.Linear(embed_dim, embed_dim)
191
+
192
+ self.init_weights()
193
+
194
+ def _reset_parameters(self):
195
+ return self.init_weights()
196
+
197
+ def init_weights(self):
198
+ """
199
+ Default initialization for Parameters of Module.
200
+ """
201
+ constant_(self.sampling_offsets.weight.data, 0.0)
202
+ thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
203
+ 2.0 * math.pi / self.num_heads
204
+ )
205
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
206
+ grid_init = (
207
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
208
+ .view(self.num_heads, 1, 1, 2)
209
+ .repeat(1, self.num_levels, self.num_points, 1)
210
+ )
211
+ for i in range(self.num_points):
212
+ grid_init[:, :, i, :] *= i + 1
213
+ with torch.no_grad():
214
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
215
+ constant_(self.attention_weights.weight.data, 0.0)
216
+ constant_(self.attention_weights.bias.data, 0.0)
217
+ xavier_uniform_(self.value_proj.weight.data)
218
+ constant_(self.value_proj.bias.data, 0.0)
219
+ xavier_uniform_(self.output_proj.weight.data)
220
+ constant_(self.output_proj.bias.data, 0.0)
221
+
222
+ def freeze_sampling_offsets(self):
223
+ print("Freeze sampling offsets")
224
+ self.sampling_offsets.weight.requires_grad = False
225
+ self.sampling_offsets.bias.requires_grad = False
226
+
227
+ def freeze_attention_weights(self):
228
+ print("Freeze attention weights")
229
+ self.attention_weights.weight.requires_grad = False
230
+ self.attention_weights.bias.requires_grad = False
231
+
232
+ def forward(
233
+ self,
234
+ query: torch.Tensor,
235
+ key: Optional[torch.Tensor] = None,
236
+ value: Optional[torch.Tensor] = None,
237
+ query_pos: Optional[torch.Tensor] = None,
238
+ key_padding_mask: Optional[torch.Tensor] = None,
239
+ reference_points: Optional[torch.Tensor] = None,
240
+ spatial_shapes: Optional[torch.Tensor] = None,
241
+ level_start_index: Optional[torch.Tensor] = None,
242
+ **kwargs
243
+ ) -> torch.Tensor:
244
+
245
+ """Forward Function of MultiScaleDeformableAttention
246
+
247
+ Args:
248
+ query (torch.Tensor): Query embeddings with shape
249
+ `(num_query, bs, embed_dim)`
250
+ key (torch.Tensor): Key embeddings with shape
251
+ `(num_key, bs, embed_dim)`
252
+ value (torch.Tensor): Value embeddings with shape
253
+ `(num_key, bs, embed_dim)`
254
+ query_pos (torch.Tensor): The position embedding for `query`. Default: None.
255
+ key_padding_mask (torch.Tensor): ByteTensor for `query`, with shape `(bs, num_key)`,
256
+ indicating which elements within `key` to be ignored in attention.
257
+ reference_points (torch.Tensor): The normalized reference points
258
+ with shape `(bs, num_query, num_levels, 2)`,
259
+ all elements is range in [0, 1], top-left (0, 0),
260
+ bottom-right (1, 1), including padding are.
261
+ or `(N, Length_{query}, num_levels, 4)`, add additional
262
+ two dimensions `(h, w)` to form reference boxes.
263
+ spatial_shapes (torch.Tensor): Spatial shape of features in different levels.
264
+ With shape `(num_levels, 2)`, last dimension represents `(h, w)`.
265
+ level_start_index (torch.Tensor): The start index of each level. A tensor with
266
+ shape `(num_levels, )` which can be represented as
267
+ `[0, h_0 * w_0, h_0 * w_0 + h_1 * w_1, ...]`.
268
+
269
+ Returns:
270
+ torch.Tensor: forward results with shape `(num_query, bs, embed_dim)`
271
+ """
272
+
273
+ if value is None:
274
+ value = query
275
+
276
+ if query_pos is not None:
277
+ query = query + query_pos
278
+
279
+ if not self.batch_first:
280
+ # change to (bs, num_query ,embed_dims)
281
+ query = query.permute(1, 0, 2)
282
+ value = value.permute(1, 0, 2)
283
+
284
+ bs, num_query, _ = query.shape
285
+ bs, num_value, _ = value.shape
286
+
287
+ assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
288
+
289
+ value = self.value_proj(value)
290
+ if key_padding_mask is not None:
291
+ value = value.masked_fill(key_padding_mask[..., None], float(0))
292
+ value = value.view(bs, num_value, self.num_heads, -1)
293
+ sampling_offsets = self.sampling_offsets(query).view(
294
+ bs, num_query, self.num_heads, self.num_levels, self.num_points, 2
295
+ )
296
+ attention_weights = self.attention_weights(query).view(
297
+ bs, num_query, self.num_heads, self.num_levels * self.num_points
298
+ )
299
+ attention_weights = attention_weights.softmax(-1)
300
+ attention_weights = attention_weights.view(
301
+ bs,
302
+ num_query,
303
+ self.num_heads,
304
+ self.num_levels,
305
+ self.num_points,
306
+ )
307
+
308
+ # bs, num_query, num_heads, num_levels, num_points, 2
309
+ if reference_points.shape[-1] == 2:
310
+ offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
311
+ sampling_locations = (
312
+ reference_points[:, :, None, :, None, :]
313
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
314
+ )
315
+ elif reference_points.shape[-1] == 4:
316
+ sampling_locations = (
317
+ reference_points[:, :, None, :, None, :2]
318
+ + sampling_offsets
319
+ / self.num_points
320
+ * reference_points[:, :, None, :, None, 2:]
321
+ * 0.5
322
+ )
323
+ else:
324
+ raise ValueError(
325
+ "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
326
+ reference_points.shape[-1]
327
+ )
328
+ )
329
+
330
+ if torch.cuda.is_available() and value.is_cuda:
331
+ halffloat = False
332
+ if value.dtype == torch.float16:
333
+ halffloat = True
334
+ value = value.float()
335
+ sampling_locations = sampling_locations.float()
336
+ attention_weights = attention_weights.float()
337
+
338
+ output = MultiScaleDeformableAttnFunction.apply(
339
+ value,
340
+ spatial_shapes,
341
+ level_start_index,
342
+ sampling_locations,
343
+ attention_weights,
344
+ self.im2col_step,
345
+ )
346
+
347
+ if halffloat:
348
+ output = output.half()
349
+ else:
350
+ output = multi_scale_deformable_attn_pytorch(
351
+ value, spatial_shapes, sampling_locations, attention_weights
352
+ )
353
+
354
+ output = self.output_proj(output)
355
+
356
+ if not self.batch_first:
357
+ output = output.permute(1, 0, 2)
358
+
359
+ return output
360
+
361
+
362
+ def create_dummy_class(klass, dependency, message=""):
363
+ """
364
+ When a dependency of a class is not available, create a dummy class which throws ImportError
365
+ when used.
366
+
367
+ Args:
368
+ klass (str): name of the class.
369
+ dependency (str): name of the dependency.
370
+ message: extra message to print
371
+ Returns:
372
+ class: a class object
373
+ """
374
+ err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass)
375
+ if message:
376
+ err = err + " " + message
377
+
378
+ class _DummyMetaClass(type):
379
+ # throw error on class attribute access
380
+ def __getattr__(_, __): # noqa: B902
381
+ raise ImportError(err)
382
+
383
+ class _Dummy(object, metaclass=_DummyMetaClass):
384
+ # throw error on constructor
385
+ def __init__(self, *args, **kwargs):
386
+ raise ImportError(err)
387
+
388
+ return _Dummy
389
+
390
+
391
+ def create_dummy_func(func, dependency, message=""):
392
+ """
393
+ When a dependency of a function is not available, create a dummy function which throws
394
+ ImportError when used.
395
+
396
+ Args:
397
+ func (str): name of the function.
398
+ dependency (str or list[str]): name(s) of the dependency.
399
+ message: extra message to print
400
+ Returns:
401
+ function: a function object
402
+ """
403
+ err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func)
404
+ if message:
405
+ err = err + " " + message
406
+
407
+ if isinstance(dependency, (list, tuple)):
408
+ dependency = ",".join(dependency)
409
+
410
+ def _dummy(*args, **kwargs):
411
+ raise ImportError(err)
412
+
413
+ return _dummy
GroundingDINO/groundingdino/models/GroundingDINO/transformer.py ADDED
@@ -0,0 +1,959 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # DINO
8
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Conditional DETR Transformer class.
12
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
13
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
14
+ # ------------------------------------------------------------------------
15
+ # Modified from DETR (https://github.com/facebookresearch/detr)
16
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
17
+ # ------------------------------------------------------------------------
18
+
19
+ from typing import Optional
20
+
21
+ import torch
22
+ import torch.utils.checkpoint as checkpoint
23
+ from torch import Tensor, nn
24
+
25
+ from groundingdino.util.misc import inverse_sigmoid
26
+
27
+ from .fuse_modules import BiAttentionBlock
28
+ from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn
29
+ from .transformer_vanilla import TransformerEncoderLayer
30
+ from .utils import (
31
+ MLP,
32
+ _get_activation_fn,
33
+ _get_clones,
34
+ gen_encoder_output_proposals,
35
+ gen_sineembed_for_position,
36
+ get_sine_pos_embed,
37
+ )
38
+
39
+
40
+ class Transformer(nn.Module):
41
+ def __init__(
42
+ self,
43
+ d_model=256,
44
+ nhead=8,
45
+ num_queries=300,
46
+ num_encoder_layers=6,
47
+ num_unicoder_layers=0,
48
+ num_decoder_layers=6,
49
+ dim_feedforward=2048,
50
+ dropout=0.0,
51
+ activation="relu",
52
+ normalize_before=False,
53
+ return_intermediate_dec=False,
54
+ query_dim=4,
55
+ num_patterns=0,
56
+ # for deformable encoder
57
+ num_feature_levels=1,
58
+ enc_n_points=4,
59
+ dec_n_points=4,
60
+ # init query
61
+ learnable_tgt_init=False,
62
+ # two stage
63
+ two_stage_type="no", # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
64
+ embed_init_tgt=False,
65
+ # for text
66
+ use_text_enhancer=False,
67
+ use_fusion_layer=False,
68
+ use_checkpoint=False,
69
+ use_transformer_ckpt=False,
70
+ use_text_cross_attention=False,
71
+ text_dropout=0.1,
72
+ fusion_dropout=0.1,
73
+ fusion_droppath=0.0,
74
+ ):
75
+ super().__init__()
76
+ self.num_feature_levels = num_feature_levels
77
+ self.num_encoder_layers = num_encoder_layers
78
+ self.num_unicoder_layers = num_unicoder_layers
79
+ self.num_decoder_layers = num_decoder_layers
80
+ self.num_queries = num_queries
81
+ assert query_dim == 4
82
+
83
+ # choose encoder layer type
84
+ encoder_layer = DeformableTransformerEncoderLayer(
85
+ d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points
86
+ )
87
+
88
+ if use_text_enhancer:
89
+ text_enhance_layer = TransformerEncoderLayer(
90
+ d_model=d_model,
91
+ nhead=nhead // 2,
92
+ dim_feedforward=dim_feedforward // 2,
93
+ dropout=text_dropout,
94
+ )
95
+ else:
96
+ text_enhance_layer = None
97
+
98
+ if use_fusion_layer:
99
+ feature_fusion_layer = BiAttentionBlock(
100
+ v_dim=d_model,
101
+ l_dim=d_model,
102
+ embed_dim=dim_feedforward // 2,
103
+ num_heads=nhead // 2,
104
+ dropout=fusion_dropout,
105
+ drop_path=fusion_droppath,
106
+ )
107
+ else:
108
+ feature_fusion_layer = None
109
+
110
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
111
+ assert encoder_norm is None
112
+ self.encoder = TransformerEncoder(
113
+ encoder_layer,
114
+ num_encoder_layers,
115
+ d_model=d_model,
116
+ num_queries=num_queries,
117
+ text_enhance_layer=text_enhance_layer,
118
+ feature_fusion_layer=feature_fusion_layer,
119
+ use_checkpoint=use_checkpoint,
120
+ use_transformer_ckpt=use_transformer_ckpt,
121
+ )
122
+
123
+ # choose decoder layer type
124
+ decoder_layer = DeformableTransformerDecoderLayer(
125
+ d_model,
126
+ dim_feedforward,
127
+ dropout,
128
+ activation,
129
+ num_feature_levels,
130
+ nhead,
131
+ dec_n_points,
132
+ use_text_cross_attention=use_text_cross_attention,
133
+ )
134
+
135
+ decoder_norm = nn.LayerNorm(d_model)
136
+ self.decoder = TransformerDecoder(
137
+ decoder_layer,
138
+ num_decoder_layers,
139
+ decoder_norm,
140
+ return_intermediate=return_intermediate_dec,
141
+ d_model=d_model,
142
+ query_dim=query_dim,
143
+ num_feature_levels=num_feature_levels,
144
+ )
145
+
146
+ self.d_model = d_model
147
+ self.nhead = nhead
148
+ self.dec_layers = num_decoder_layers
149
+ self.num_queries = num_queries # useful for single stage model only
150
+ self.num_patterns = num_patterns
151
+ if not isinstance(num_patterns, int):
152
+ Warning("num_patterns should be int but {}".format(type(num_patterns)))
153
+ self.num_patterns = 0
154
+
155
+ if num_feature_levels > 1:
156
+ if self.num_encoder_layers > 0:
157
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
158
+ else:
159
+ self.level_embed = None
160
+
161
+ self.learnable_tgt_init = learnable_tgt_init
162
+ assert learnable_tgt_init, "why not learnable_tgt_init"
163
+ self.embed_init_tgt = embed_init_tgt
164
+ if (two_stage_type != "no" and embed_init_tgt) or (two_stage_type == "no"):
165
+ self.tgt_embed = nn.Embedding(self.num_queries, d_model)
166
+ nn.init.normal_(self.tgt_embed.weight.data)
167
+ else:
168
+ self.tgt_embed = None
169
+
170
+ # for two stage
171
+ self.two_stage_type = two_stage_type
172
+ assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
173
+ two_stage_type
174
+ )
175
+ if two_stage_type == "standard":
176
+ # anchor selection at the output of encoder
177
+ self.enc_output = nn.Linear(d_model, d_model)
178
+ self.enc_output_norm = nn.LayerNorm(d_model)
179
+ self.two_stage_wh_embedding = None
180
+
181
+ if two_stage_type == "no":
182
+ self.init_ref_points(num_queries) # init self.refpoint_embed
183
+
184
+ self.enc_out_class_embed = None
185
+ self.enc_out_bbox_embed = None
186
+
187
+ self._reset_parameters()
188
+
189
+ def _reset_parameters(self):
190
+ for p in self.parameters():
191
+ if p.dim() > 1:
192
+ nn.init.xavier_uniform_(p)
193
+ for m in self.modules():
194
+ if isinstance(m, MSDeformAttn):
195
+ m._reset_parameters()
196
+ if self.num_feature_levels > 1 and self.level_embed is not None:
197
+ nn.init.normal_(self.level_embed)
198
+
199
+ def get_valid_ratio(self, mask):
200
+ _, H, W = mask.shape
201
+ valid_H = torch.sum(~mask[:, :, 0], 1)
202
+ valid_W = torch.sum(~mask[:, 0, :], 1)
203
+ valid_ratio_h = valid_H.float() / H
204
+ valid_ratio_w = valid_W.float() / W
205
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
206
+ return valid_ratio
207
+
208
+ def init_ref_points(self, use_num_queries):
209
+ self.refpoint_embed = nn.Embedding(use_num_queries, 4)
210
+
211
+ def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None):
212
+ """
213
+ Input:
214
+ - srcs: List of multi features [bs, ci, hi, wi]
215
+ - masks: List of multi masks [bs, hi, wi]
216
+ - refpoint_embed: [bs, num_dn, 4]. None in infer
217
+ - pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
218
+ - tgt: [bs, num_dn, d_model]. None in infer
219
+
220
+ """
221
+ # prepare input for encoder
222
+ src_flatten = []
223
+ mask_flatten = []
224
+ lvl_pos_embed_flatten = []
225
+ spatial_shapes = []
226
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
227
+ bs, c, h, w = src.shape
228
+ spatial_shape = (h, w)
229
+ spatial_shapes.append(spatial_shape)
230
+
231
+ src = src.flatten(2).transpose(1, 2) # bs, hw, c
232
+ mask = mask.flatten(1) # bs, hw
233
+ pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
234
+ if self.num_feature_levels > 1 and self.level_embed is not None:
235
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
236
+ else:
237
+ lvl_pos_embed = pos_embed
238
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
239
+ src_flatten.append(src)
240
+ mask_flatten.append(mask)
241
+ src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
242
+ mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
243
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
244
+ spatial_shapes = torch.as_tensor(
245
+ spatial_shapes, dtype=torch.long, device=src_flatten.device
246
+ )
247
+ level_start_index = torch.cat(
248
+ (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
249
+ )
250
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
251
+
252
+ # two stage
253
+ enc_topk_proposals = enc_refpoint_embed = None
254
+
255
+ #########################################################
256
+ # Begin Encoder
257
+ #########################################################
258
+ memory, memory_text = self.encoder(
259
+ src_flatten,
260
+ pos=lvl_pos_embed_flatten,
261
+ level_start_index=level_start_index,
262
+ spatial_shapes=spatial_shapes,
263
+ valid_ratios=valid_ratios,
264
+ key_padding_mask=mask_flatten,
265
+ memory_text=text_dict["encoded_text"],
266
+ text_attention_mask=~text_dict["text_token_mask"],
267
+ # we ~ the mask . False means use the token; True means pad the token
268
+ position_ids=text_dict["position_ids"],
269
+ text_self_attention_masks=text_dict["text_self_attention_masks"],
270
+ )
271
+ #########################################################
272
+ # End Encoder
273
+ # - memory: bs, \sum{hw}, c
274
+ # - mask_flatten: bs, \sum{hw}
275
+ # - lvl_pos_embed_flatten: bs, \sum{hw}, c
276
+ # - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
277
+ # - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
278
+ #########################################################
279
+ text_dict["encoded_text"] = memory_text
280
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
281
+ # if memory.isnan().any() | memory.isinf().any():
282
+ # import ipdb; ipdb.set_trace()
283
+
284
+ if self.two_stage_type == "standard":
285
+ output_memory, output_proposals = gen_encoder_output_proposals(
286
+ memory, mask_flatten, spatial_shapes
287
+ )
288
+ output_memory = self.enc_output_norm(self.enc_output(output_memory))
289
+
290
+ if text_dict is not None:
291
+ enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict)
292
+ else:
293
+ enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
294
+
295
+ topk_logits = enc_outputs_class_unselected.max(-1)[0]
296
+ enc_outputs_coord_unselected = (
297
+ self.enc_out_bbox_embed(output_memory) + output_proposals
298
+ ) # (bs, \sum{hw}, 4) unsigmoid
299
+ topk = self.num_queries
300
+
301
+ topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq
302
+
303
+ # gather boxes
304
+ refpoint_embed_undetach = torch.gather(
305
+ enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
306
+ ) # unsigmoid
307
+ refpoint_embed_ = refpoint_embed_undetach.detach()
308
+ init_box_proposal = torch.gather(
309
+ output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
310
+ ).sigmoid() # sigmoid
311
+
312
+ # gather tgt
313
+ tgt_undetach = torch.gather(
314
+ output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
315
+ )
316
+ if self.embed_init_tgt:
317
+ tgt_ = (
318
+ self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
319
+ ) # nq, bs, d_model
320
+ else:
321
+ tgt_ = tgt_undetach.detach()
322
+
323
+ if refpoint_embed is not None:
324
+ refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
325
+ tgt = torch.cat([tgt, tgt_], dim=1)
326
+ else:
327
+ refpoint_embed, tgt = refpoint_embed_, tgt_
328
+
329
+ elif self.two_stage_type == "no":
330
+ tgt_ = (
331
+ self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
332
+ ) # nq, bs, d_model
333
+ refpoint_embed_ = (
334
+ self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
335
+ ) # nq, bs, 4
336
+
337
+ if refpoint_embed is not None:
338
+ refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
339
+ tgt = torch.cat([tgt, tgt_], dim=1)
340
+ else:
341
+ refpoint_embed, tgt = refpoint_embed_, tgt_
342
+
343
+ if self.num_patterns > 0:
344
+ tgt_embed = tgt.repeat(1, self.num_patterns, 1)
345
+ refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
346
+ tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(
347
+ self.num_queries, 1
348
+ ) # 1, n_q*n_pat, d_model
349
+ tgt = tgt_embed + tgt_pat
350
+
351
+ init_box_proposal = refpoint_embed_.sigmoid()
352
+
353
+ else:
354
+ raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
355
+ #########################################################
356
+ # End preparing tgt
357
+ # - tgt: bs, NQ, d_model
358
+ # - refpoint_embed(unsigmoid): bs, NQ, d_model
359
+ #########################################################
360
+
361
+ #########################################################
362
+ # Begin Decoder
363
+ #########################################################
364
+ hs, references = self.decoder(
365
+ tgt=tgt.transpose(0, 1),
366
+ memory=memory.transpose(0, 1),
367
+ memory_key_padding_mask=mask_flatten,
368
+ pos=lvl_pos_embed_flatten.transpose(0, 1),
369
+ refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
370
+ level_start_index=level_start_index,
371
+ spatial_shapes=spatial_shapes,
372
+ valid_ratios=valid_ratios,
373
+ tgt_mask=attn_mask,
374
+ memory_text=text_dict["encoded_text"],
375
+ text_attention_mask=~text_dict["text_token_mask"],
376
+ # we ~ the mask . False means use the token; True means pad the token
377
+ )
378
+ #########################################################
379
+ # End Decoder
380
+ # hs: n_dec, bs, nq, d_model
381
+ # references: n_dec+1, bs, nq, query_dim
382
+ #########################################################
383
+
384
+ #########################################################
385
+ # Begin postprocess
386
+ #########################################################
387
+ if self.two_stage_type == "standard":
388
+ hs_enc = tgt_undetach.unsqueeze(0)
389
+ ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
390
+ else:
391
+ hs_enc = ref_enc = None
392
+ #########################################################
393
+ # End postprocess
394
+ # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
395
+ # ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None
396
+ #########################################################
397
+
398
+ return hs, references, hs_enc, ref_enc, init_box_proposal
399
+ # hs: (n_dec, bs, nq, d_model)
400
+ # references: sigmoid coordinates. (n_dec+1, bs, bq, 4)
401
+ # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
402
+ # ref_enc: sigmoid coordinates. \
403
+ # (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
404
+
405
+
406
+ class TransformerEncoder(nn.Module):
407
+ def __init__(
408
+ self,
409
+ encoder_layer,
410
+ num_layers,
411
+ d_model=256,
412
+ num_queries=300,
413
+ enc_layer_share=False,
414
+ text_enhance_layer=None,
415
+ feature_fusion_layer=None,
416
+ use_checkpoint=False,
417
+ use_transformer_ckpt=False,
418
+ ):
419
+ """_summary_
420
+
421
+ Args:
422
+ encoder_layer (_type_): _description_
423
+ num_layers (_type_): _description_
424
+ norm (_type_, optional): _description_. Defaults to None.
425
+ d_model (int, optional): _description_. Defaults to 256.
426
+ num_queries (int, optional): _description_. Defaults to 300.
427
+ enc_layer_share (bool, optional): _description_. Defaults to False.
428
+
429
+ """
430
+ super().__init__()
431
+ # prepare layers
432
+ self.layers = []
433
+ self.text_layers = []
434
+ self.fusion_layers = []
435
+ if num_layers > 0:
436
+ self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share)
437
+
438
+ if text_enhance_layer is not None:
439
+ self.text_layers = _get_clones(
440
+ text_enhance_layer, num_layers, layer_share=enc_layer_share
441
+ )
442
+ if feature_fusion_layer is not None:
443
+ self.fusion_layers = _get_clones(
444
+ feature_fusion_layer, num_layers, layer_share=enc_layer_share
445
+ )
446
+ else:
447
+ self.layers = []
448
+ del encoder_layer
449
+
450
+ if text_enhance_layer is not None:
451
+ self.text_layers = []
452
+ del text_enhance_layer
453
+ if feature_fusion_layer is not None:
454
+ self.fusion_layers = []
455
+ del feature_fusion_layer
456
+
457
+ self.query_scale = None
458
+ self.num_queries = num_queries
459
+ self.num_layers = num_layers
460
+ self.d_model = d_model
461
+
462
+ self.use_checkpoint = use_checkpoint
463
+ self.use_transformer_ckpt = use_transformer_ckpt
464
+
465
+ @staticmethod
466
+ def get_reference_points(spatial_shapes, valid_ratios, device):
467
+ reference_points_list = []
468
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
469
+
470
+ ref_y, ref_x = torch.meshgrid(
471
+ torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
472
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
473
+ )
474
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
475
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
476
+ ref = torch.stack((ref_x, ref_y), -1)
477
+ reference_points_list.append(ref)
478
+ reference_points = torch.cat(reference_points_list, 1)
479
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
480
+ return reference_points
481
+
482
+ def forward(
483
+ self,
484
+ # for images
485
+ src: Tensor,
486
+ pos: Tensor,
487
+ spatial_shapes: Tensor,
488
+ level_start_index: Tensor,
489
+ valid_ratios: Tensor,
490
+ key_padding_mask: Tensor,
491
+ # for texts
492
+ memory_text: Tensor = None,
493
+ text_attention_mask: Tensor = None,
494
+ pos_text: Tensor = None,
495
+ text_self_attention_masks: Tensor = None,
496
+ position_ids: Tensor = None,
497
+ ):
498
+ """
499
+ Input:
500
+ - src: [bs, sum(hi*wi), 256]
501
+ - pos: pos embed for src. [bs, sum(hi*wi), 256]
502
+ - spatial_shapes: h,w of each level [num_level, 2]
503
+ - level_start_index: [num_level] start point of level in sum(hi*wi).
504
+ - valid_ratios: [bs, num_level, 2]
505
+ - key_padding_mask: [bs, sum(hi*wi)]
506
+
507
+ - memory_text: bs, n_text, 256
508
+ - text_attention_mask: bs, n_text
509
+ False for no padding; True for padding
510
+ - pos_text: bs, n_text, 256
511
+
512
+ - position_ids: bs, n_text
513
+ Intermedia:
514
+ - reference_points: [bs, sum(hi*wi), num_level, 2]
515
+ Outpus:
516
+ - output: [bs, sum(hi*wi), 256]
517
+ """
518
+
519
+ output = src
520
+
521
+ # preparation and reshape
522
+ if self.num_layers > 0:
523
+ reference_points = self.get_reference_points(
524
+ spatial_shapes, valid_ratios, device=src.device
525
+ )
526
+
527
+ if self.text_layers:
528
+ # generate pos_text
529
+ bs, n_text, text_dim = memory_text.shape
530
+ if pos_text is None and position_ids is None:
531
+ pos_text = (
532
+ torch.arange(n_text, device=memory_text.device)
533
+ .float()
534
+ .unsqueeze(0)
535
+ .unsqueeze(-1)
536
+ .repeat(bs, 1, 1)
537
+ )
538
+ pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False)
539
+ if position_ids is not None:
540
+ pos_text = get_sine_pos_embed(
541
+ position_ids[..., None], num_pos_feats=256, exchange_xy=False
542
+ )
543
+
544
+ # main process
545
+ for layer_id, layer in enumerate(self.layers):
546
+ # if output.isnan().any() or memory_text.isnan().any():
547
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
548
+ # import ipdb; ipdb.set_trace()
549
+ if self.fusion_layers:
550
+ if self.use_checkpoint:
551
+ output, memory_text = checkpoint.checkpoint(
552
+ self.fusion_layers[layer_id],
553
+ output,
554
+ memory_text,
555
+ key_padding_mask,
556
+ text_attention_mask,
557
+ )
558
+ else:
559
+ output, memory_text = self.fusion_layers[layer_id](
560
+ v=output,
561
+ l=memory_text,
562
+ attention_mask_v=key_padding_mask,
563
+ attention_mask_l=text_attention_mask,
564
+ )
565
+
566
+ if self.text_layers:
567
+ memory_text = self.text_layers[layer_id](
568
+ src=memory_text.transpose(0, 1),
569
+ src_mask=~text_self_attention_masks, # note we use ~ for mask here
570
+ src_key_padding_mask=text_attention_mask,
571
+ pos=(pos_text.transpose(0, 1) if pos_text is not None else None),
572
+ ).transpose(0, 1)
573
+
574
+ # main process
575
+ if self.use_transformer_ckpt:
576
+ output = checkpoint.checkpoint(
577
+ layer,
578
+ output,
579
+ pos,
580
+ reference_points,
581
+ spatial_shapes,
582
+ level_start_index,
583
+ key_padding_mask,
584
+ )
585
+ else:
586
+ output = layer(
587
+ src=output,
588
+ pos=pos,
589
+ reference_points=reference_points,
590
+ spatial_shapes=spatial_shapes,
591
+ level_start_index=level_start_index,
592
+ key_padding_mask=key_padding_mask,
593
+ )
594
+
595
+ return output, memory_text
596
+
597
+
598
+ class TransformerDecoder(nn.Module):
599
+ def __init__(
600
+ self,
601
+ decoder_layer,
602
+ num_layers,
603
+ norm=None,
604
+ return_intermediate=False,
605
+ d_model=256,
606
+ query_dim=4,
607
+ num_feature_levels=1,
608
+ ):
609
+ super().__init__()
610
+ if num_layers > 0:
611
+ self.layers = _get_clones(decoder_layer, num_layers)
612
+ else:
613
+ self.layers = []
614
+ self.num_layers = num_layers
615
+ self.norm = norm
616
+ self.return_intermediate = return_intermediate
617
+ assert return_intermediate, "support return_intermediate only"
618
+ self.query_dim = query_dim
619
+ assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
620
+ self.num_feature_levels = num_feature_levels
621
+
622
+ self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
623
+ self.query_pos_sine_scale = None
624
+
625
+ self.query_scale = None
626
+ self.bbox_embed = None
627
+ self.class_embed = None
628
+
629
+ self.d_model = d_model
630
+
631
+ self.ref_anchor_head = None
632
+
633
+ def forward(
634
+ self,
635
+ tgt,
636
+ memory,
637
+ tgt_mask: Optional[Tensor] = None,
638
+ memory_mask: Optional[Tensor] = None,
639
+ tgt_key_padding_mask: Optional[Tensor] = None,
640
+ memory_key_padding_mask: Optional[Tensor] = None,
641
+ pos: Optional[Tensor] = None,
642
+ refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
643
+ # for memory
644
+ level_start_index: Optional[Tensor] = None, # num_levels
645
+ spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
646
+ valid_ratios: Optional[Tensor] = None,
647
+ # for text
648
+ memory_text: Optional[Tensor] = None,
649
+ text_attention_mask: Optional[Tensor] = None,
650
+ ):
651
+ """
652
+ Input:
653
+ - tgt: nq, bs, d_model
654
+ - memory: hw, bs, d_model
655
+ - pos: hw, bs, d_model
656
+ - refpoints_unsigmoid: nq, bs, 2/4
657
+ - valid_ratios/spatial_shapes: bs, nlevel, 2
658
+ """
659
+ output = tgt
660
+
661
+ intermediate = []
662
+ reference_points = refpoints_unsigmoid.sigmoid()
663
+ ref_points = [reference_points]
664
+
665
+ for layer_id, layer in enumerate(self.layers):
666
+
667
+ if reference_points.shape[-1] == 4:
668
+ reference_points_input = (
669
+ reference_points[:, :, None]
670
+ * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
671
+ ) # nq, bs, nlevel, 4
672
+ else:
673
+ assert reference_points.shape[-1] == 2
674
+ reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
675
+ query_sine_embed = gen_sineembed_for_position(
676
+ reference_points_input[:, :, 0, :]
677
+ ) # nq, bs, 256*2
678
+
679
+ # conditional query
680
+ raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256
681
+ pos_scale = self.query_scale(output) if self.query_scale is not None else 1
682
+ query_pos = pos_scale * raw_query_pos
683
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
684
+ # if query_pos.isnan().any() | query_pos.isinf().any():
685
+ # import ipdb; ipdb.set_trace()
686
+
687
+ # main process
688
+ output = layer(
689
+ tgt=output,
690
+ tgt_query_pos=query_pos,
691
+ tgt_query_sine_embed=query_sine_embed,
692
+ tgt_key_padding_mask=tgt_key_padding_mask,
693
+ tgt_reference_points=reference_points_input,
694
+ memory_text=memory_text,
695
+ text_attention_mask=text_attention_mask,
696
+ memory=memory,
697
+ memory_key_padding_mask=memory_key_padding_mask,
698
+ memory_level_start_index=level_start_index,
699
+ memory_spatial_shapes=spatial_shapes,
700
+ memory_pos=pos,
701
+ self_attn_mask=tgt_mask,
702
+ cross_attn_mask=memory_mask,
703
+ )
704
+ if output.isnan().any() | output.isinf().any():
705
+ print(f"output layer_id {layer_id} is nan")
706
+ try:
707
+ num_nan = output.isnan().sum().item()
708
+ num_inf = output.isinf().sum().item()
709
+ print(f"num_nan {num_nan}, num_inf {num_inf}")
710
+ except Exception as e:
711
+ print(e)
712
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
713
+ # import ipdb; ipdb.set_trace()
714
+
715
+ # iter update
716
+ if self.bbox_embed is not None:
717
+ # box_holder = self.bbox_embed(output)
718
+ # box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points)
719
+ # new_reference_points = box_holder[..., :self.query_dim].sigmoid()
720
+
721
+ reference_before_sigmoid = inverse_sigmoid(reference_points)
722
+ delta_unsig = self.bbox_embed[layer_id](output)
723
+ outputs_unsig = delta_unsig + reference_before_sigmoid
724
+ new_reference_points = outputs_unsig.sigmoid()
725
+
726
+ reference_points = new_reference_points.detach()
727
+ # if layer_id != self.num_layers - 1:
728
+ ref_points.append(new_reference_points)
729
+
730
+ intermediate.append(self.norm(output))
731
+
732
+ return [
733
+ [itm_out.transpose(0, 1) for itm_out in intermediate],
734
+ [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
735
+ ]
736
+
737
+
738
+ class DeformableTransformerEncoderLayer(nn.Module):
739
+ def __init__(
740
+ self,
741
+ d_model=256,
742
+ d_ffn=1024,
743
+ dropout=0.1,
744
+ activation="relu",
745
+ n_levels=4,
746
+ n_heads=8,
747
+ n_points=4,
748
+ ):
749
+ super().__init__()
750
+
751
+ # self attention
752
+ self.self_attn = MSDeformAttn(
753
+ embed_dim=d_model,
754
+ num_levels=n_levels,
755
+ num_heads=n_heads,
756
+ num_points=n_points,
757
+ batch_first=True,
758
+ )
759
+ self.dropout1 = nn.Dropout(dropout)
760
+ self.norm1 = nn.LayerNorm(d_model)
761
+
762
+ # ffn
763
+ self.linear1 = nn.Linear(d_model, d_ffn)
764
+ self.activation = _get_activation_fn(activation, d_model=d_ffn)
765
+ self.dropout2 = nn.Dropout(dropout)
766
+ self.linear2 = nn.Linear(d_ffn, d_model)
767
+ self.dropout3 = nn.Dropout(dropout)
768
+ self.norm2 = nn.LayerNorm(d_model)
769
+
770
+ @staticmethod
771
+ def with_pos_embed(tensor, pos):
772
+ return tensor if pos is None else tensor + pos
773
+
774
+ def forward_ffn(self, src):
775
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
776
+ src = src + self.dropout3(src2)
777
+ src = self.norm2(src)
778
+ return src
779
+
780
+ def forward(
781
+ self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None
782
+ ):
783
+ # self attention
784
+ # import ipdb; ipdb.set_trace()
785
+ src2 = self.self_attn(
786
+ query=self.with_pos_embed(src, pos),
787
+ reference_points=reference_points,
788
+ value=src,
789
+ spatial_shapes=spatial_shapes,
790
+ level_start_index=level_start_index,
791
+ key_padding_mask=key_padding_mask,
792
+ )
793
+ src = src + self.dropout1(src2)
794
+ src = self.norm1(src)
795
+
796
+ # ffn
797
+ src = self.forward_ffn(src)
798
+
799
+ return src
800
+
801
+
802
+ class DeformableTransformerDecoderLayer(nn.Module):
803
+ def __init__(
804
+ self,
805
+ d_model=256,
806
+ d_ffn=1024,
807
+ dropout=0.1,
808
+ activation="relu",
809
+ n_levels=4,
810
+ n_heads=8,
811
+ n_points=4,
812
+ use_text_feat_guide=False,
813
+ use_text_cross_attention=False,
814
+ ):
815
+ super().__init__()
816
+
817
+ # cross attention
818
+ self.cross_attn = MSDeformAttn(
819
+ embed_dim=d_model,
820
+ num_levels=n_levels,
821
+ num_heads=n_heads,
822
+ num_points=n_points,
823
+ batch_first=True,
824
+ )
825
+ self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
826
+ self.norm1 = nn.LayerNorm(d_model)
827
+
828
+ # cross attention text
829
+ if use_text_cross_attention:
830
+ self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
831
+ self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
832
+ self.catext_norm = nn.LayerNorm(d_model)
833
+
834
+ # self attention
835
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
836
+ self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
837
+ self.norm2 = nn.LayerNorm(d_model)
838
+
839
+ # ffn
840
+ self.linear1 = nn.Linear(d_model, d_ffn)
841
+ self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
842
+ self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
843
+ self.linear2 = nn.Linear(d_ffn, d_model)
844
+ self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
845
+ self.norm3 = nn.LayerNorm(d_model)
846
+
847
+ self.key_aware_proj = None
848
+ self.use_text_feat_guide = use_text_feat_guide
849
+ assert not use_text_feat_guide
850
+ self.use_text_cross_attention = use_text_cross_attention
851
+
852
+ def rm_self_attn_modules(self):
853
+ self.self_attn = None
854
+ self.dropout2 = None
855
+ self.norm2 = None
856
+
857
+ @staticmethod
858
+ def with_pos_embed(tensor, pos):
859
+ return tensor if pos is None else tensor + pos
860
+
861
+ def forward_ffn(self, tgt):
862
+ with torch.cuda.amp.autocast(enabled=False):
863
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
864
+ tgt = tgt + self.dropout4(tgt2)
865
+ tgt = self.norm3(tgt)
866
+ return tgt
867
+
868
+ def forward(
869
+ self,
870
+ # for tgt
871
+ tgt: Optional[Tensor], # nq, bs, d_model
872
+ tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
873
+ tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
874
+ tgt_key_padding_mask: Optional[Tensor] = None,
875
+ tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
876
+ memory_text: Optional[Tensor] = None, # bs, num_token, d_model
877
+ text_attention_mask: Optional[Tensor] = None, # bs, num_token
878
+ # for memory
879
+ memory: Optional[Tensor] = None, # hw, bs, d_model
880
+ memory_key_padding_mask: Optional[Tensor] = None,
881
+ memory_level_start_index: Optional[Tensor] = None, # num_levels
882
+ memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
883
+ memory_pos: Optional[Tensor] = None, # pos for memory
884
+ # sa
885
+ self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
886
+ cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
887
+ ):
888
+ """
889
+ Input:
890
+ - tgt/tgt_query_pos: nq, bs, d_model
891
+ -
892
+ """
893
+ assert cross_attn_mask is None
894
+
895
+ # self attention
896
+ if self.self_attn is not None:
897
+ # import ipdb; ipdb.set_trace()
898
+ q = k = self.with_pos_embed(tgt, tgt_query_pos)
899
+ tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
900
+ tgt = tgt + self.dropout2(tgt2)
901
+ tgt = self.norm2(tgt)
902
+
903
+ if self.use_text_cross_attention:
904
+ tgt2 = self.ca_text(
905
+ self.with_pos_embed(tgt, tgt_query_pos),
906
+ memory_text.transpose(0, 1),
907
+ memory_text.transpose(0, 1),
908
+ key_padding_mask=text_attention_mask,
909
+ )[0]
910
+ tgt = tgt + self.catext_dropout(tgt2)
911
+ tgt = self.catext_norm(tgt)
912
+
913
+ tgt2 = self.cross_attn(
914
+ query=self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
915
+ reference_points=tgt_reference_points.transpose(0, 1).contiguous(),
916
+ value=memory.transpose(0, 1),
917
+ spatial_shapes=memory_spatial_shapes,
918
+ level_start_index=memory_level_start_index,
919
+ key_padding_mask=memory_key_padding_mask,
920
+ ).transpose(0, 1)
921
+ tgt = tgt + self.dropout1(tgt2)
922
+ tgt = self.norm1(tgt)
923
+
924
+ # ffn
925
+ tgt = self.forward_ffn(tgt)
926
+
927
+ return tgt
928
+
929
+
930
+ def build_transformer(args):
931
+ return Transformer(
932
+ d_model=args.hidden_dim,
933
+ dropout=args.dropout,
934
+ nhead=args.nheads,
935
+ num_queries=args.num_queries,
936
+ dim_feedforward=args.dim_feedforward,
937
+ num_encoder_layers=args.enc_layers,
938
+ num_decoder_layers=args.dec_layers,
939
+ normalize_before=args.pre_norm,
940
+ return_intermediate_dec=True,
941
+ query_dim=args.query_dim,
942
+ activation=args.transformer_activation,
943
+ num_patterns=args.num_patterns,
944
+ num_feature_levels=args.num_feature_levels,
945
+ enc_n_points=args.enc_n_points,
946
+ dec_n_points=args.dec_n_points,
947
+ learnable_tgt_init=True,
948
+ # two stage
949
+ two_stage_type=args.two_stage_type, # ['no', 'standard', 'early']
950
+ embed_init_tgt=args.embed_init_tgt,
951
+ use_text_enhancer=args.use_text_enhancer,
952
+ use_fusion_layer=args.use_fusion_layer,
953
+ use_checkpoint=args.use_checkpoint,
954
+ use_transformer_ckpt=args.use_transformer_ckpt,
955
+ use_text_cross_attention=args.use_text_cross_attention,
956
+ text_dropout=args.text_dropout,
957
+ fusion_dropout=args.fusion_dropout,
958
+ fusion_droppath=args.fusion_droppath,
959
+ )
GroundingDINO/groundingdino/models/GroundingDINO/transformer_vanilla.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
8
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
9
+ """
10
+ DETR Transformer class.
11
+
12
+ Copy-paste from torch.nn.Transformer with modifications:
13
+ * positional encodings are passed in MHattention
14
+ * extra LN at the end of encoder is removed
15
+ * decoder returns a stack of activations from all decoding layers
16
+ """
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from torch import Tensor, nn
22
+
23
+ from .utils import (
24
+ MLP,
25
+ _get_activation_fn,
26
+ _get_clones,
27
+ gen_encoder_output_proposals,
28
+ gen_sineembed_for_position,
29
+ sigmoid_focal_loss,
30
+ )
31
+
32
+
33
+ class TextTransformer(nn.Module):
34
+ def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1):
35
+ super().__init__()
36
+ self.num_layers = num_layers
37
+ self.d_model = d_model
38
+ self.nheads = nheads
39
+ self.dim_feedforward = dim_feedforward
40
+ self.norm = None
41
+
42
+ single_encoder_layer = TransformerEncoderLayer(
43
+ d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout
44
+ )
45
+ self.layers = _get_clones(single_encoder_layer, num_layers)
46
+
47
+ def forward(self, memory_text: torch.Tensor, text_attention_mask: torch.Tensor):
48
+ """
49
+
50
+ Args:
51
+ text_attention_mask: bs, num_token
52
+ memory_text: bs, num_token, d_model
53
+
54
+ Raises:
55
+ RuntimeError: _description_
56
+
57
+ Returns:
58
+ output: bs, num_token, d_model
59
+ """
60
+
61
+ output = memory_text.transpose(0, 1)
62
+
63
+ for layer in self.layers:
64
+ output = layer(output, src_key_padding_mask=text_attention_mask)
65
+
66
+ if self.norm is not None:
67
+ output = self.norm(output)
68
+
69
+ return output.transpose(0, 1)
70
+
71
+
72
+ class TransformerEncoderLayer(nn.Module):
73
+ def __init__(
74
+ self,
75
+ d_model,
76
+ nhead,
77
+ dim_feedforward=2048,
78
+ dropout=0.1,
79
+ activation="relu",
80
+ normalize_before=False,
81
+ ):
82
+ super().__init__()
83
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
84
+ # Implementation of Feedforward model
85
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
86
+ self.dropout = nn.Dropout(dropout)
87
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
88
+
89
+ self.norm1 = nn.LayerNorm(d_model)
90
+ self.norm2 = nn.LayerNorm(d_model)
91
+ self.dropout1 = nn.Dropout(dropout)
92
+ self.dropout2 = nn.Dropout(dropout)
93
+
94
+ self.activation = _get_activation_fn(activation)
95
+ self.normalize_before = normalize_before
96
+ self.nhead = nhead
97
+
98
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
99
+ return tensor if pos is None else tensor + pos
100
+
101
+ def forward(
102
+ self,
103
+ src,
104
+ src_mask: Optional[Tensor] = None,
105
+ src_key_padding_mask: Optional[Tensor] = None,
106
+ pos: Optional[Tensor] = None,
107
+ ):
108
+ # repeat attn mask
109
+ if src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]:
110
+ # bs, num_q, num_k
111
+ src_mask = src_mask.repeat(self.nhead, 1, 1)
112
+
113
+ q = k = self.with_pos_embed(src, pos)
114
+
115
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0]
116
+
117
+ # src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
118
+ src = src + self.dropout1(src2)
119
+ src = self.norm1(src)
120
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
121
+ src = src + self.dropout2(src2)
122
+ src = self.norm2(src)
123
+ return src
GroundingDINO/groundingdino/models/GroundingDINO/utils.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+
8
+ import copy
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import Tensor, nn
14
+
15
+
16
+ def _get_clones(module, N, layer_share=False):
17
+ # import ipdb; ipdb.set_trace()
18
+ if layer_share:
19
+ return nn.ModuleList([module for i in range(N)])
20
+ else:
21
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
22
+
23
+
24
+ def get_sine_pos_embed(
25
+ pos_tensor: torch.Tensor,
26
+ num_pos_feats: int = 128,
27
+ temperature: int = 10000,
28
+ exchange_xy: bool = True,
29
+ ):
30
+ """generate sine position embedding from a position tensor
31
+ Args:
32
+ pos_tensor (torch.Tensor): shape: [..., n].
33
+ num_pos_feats (int): projected shape for each float in the tensor.
34
+ temperature (int): temperature in the sine/cosine function.
35
+ exchange_xy (bool, optional): exchange pos x and pos y. \
36
+ For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True.
37
+ Returns:
38
+ pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
39
+ """
40
+ scale = 2 * math.pi
41
+ dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
42
+ dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
43
+
44
+ def sine_func(x: torch.Tensor):
45
+ sin_x = x * scale / dim_t
46
+ sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3).flatten(2)
47
+ return sin_x
48
+
49
+ pos_res = [sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)]
50
+ if exchange_xy:
51
+ pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
52
+ pos_res = torch.cat(pos_res, dim=-1)
53
+ return pos_res
54
+
55
+
56
+ def gen_encoder_output_proposals(
57
+ memory: Tensor, memory_padding_mask: Tensor, spatial_shapes: Tensor, learnedwh=None
58
+ ):
59
+ """
60
+ Input:
61
+ - memory: bs, \sum{hw}, d_model
62
+ - memory_padding_mask: bs, \sum{hw}
63
+ - spatial_shapes: nlevel, 2
64
+ - learnedwh: 2
65
+ Output:
66
+ - output_memory: bs, \sum{hw}, d_model
67
+ - output_proposals: bs, \sum{hw}, 4
68
+ """
69
+ N_, S_, C_ = memory.shape
70
+ proposals = []
71
+ _cur = 0
72
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
73
+ mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(N_, H_, W_, 1)
74
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
75
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
76
+
77
+ # import ipdb; ipdb.set_trace()
78
+
79
+ grid_y, grid_x = torch.meshgrid(
80
+ torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
81
+ torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
82
+ )
83
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
84
+
85
+ scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
86
+ grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
87
+
88
+ if learnedwh is not None:
89
+ # import ipdb; ipdb.set_trace()
90
+ wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
91
+ else:
92
+ wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
93
+
94
+ # scale = torch.cat([W_[None].unsqueeze(-1), H_[None].unsqueeze(-1)], 1).view(1, 1, 1, 2).repeat(N_, 1, 1, 1)
95
+ # grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
96
+ # wh = torch.ones_like(grid) / scale
97
+ proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
98
+ proposals.append(proposal)
99
+ _cur += H_ * W_
100
+ # import ipdb; ipdb.set_trace()
101
+ output_proposals = torch.cat(proposals, 1)
102
+ output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(
103
+ -1, keepdim=True
104
+ )
105
+ output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid
106
+ output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf"))
107
+ output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
108
+
109
+ output_memory = memory
110
+ output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
111
+ output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
112
+
113
+ # output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
114
+ # output_memory = output_memory.masked_fill(~output_proposals_valid, float('inf'))
115
+
116
+ return output_memory, output_proposals
117
+
118
+
119
+ class RandomBoxPerturber:
120
+ def __init__(
121
+ self, x_noise_scale=0.2, y_noise_scale=0.2, w_noise_scale=0.2, h_noise_scale=0.2
122
+ ) -> None:
123
+ self.noise_scale = torch.Tensor(
124
+ [x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale]
125
+ )
126
+
127
+ def __call__(self, refanchors: Tensor) -> Tensor:
128
+ nq, bs, query_dim = refanchors.shape
129
+ device = refanchors.device
130
+
131
+ noise_raw = torch.rand_like(refanchors)
132
+ noise_scale = self.noise_scale.to(device)[:query_dim]
133
+
134
+ new_refanchors = refanchors * (1 + (noise_raw - 0.5) * noise_scale)
135
+ return new_refanchors.clamp_(0, 1)
136
+
137
+
138
+ def sigmoid_focal_loss(
139
+ inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, no_reduction=False
140
+ ):
141
+ """
142
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
143
+ Args:
144
+ inputs: A float tensor of arbitrary shape.
145
+ The predictions for each example.
146
+ targets: A float tensor with the same shape as inputs. Stores the binary
147
+ classification label for each element in inputs
148
+ (0 for the negative class and 1 for the positive class).
149
+ alpha: (optional) Weighting factor in range (0,1) to balance
150
+ positive vs negative examples. Default = -1 (no weighting).
151
+ gamma: Exponent of the modulating factor (1 - p_t) to
152
+ balance easy vs hard examples.
153
+ Returns:
154
+ Loss tensor
155
+ """
156
+ prob = inputs.sigmoid()
157
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
158
+ p_t = prob * targets + (1 - prob) * (1 - targets)
159
+ loss = ce_loss * ((1 - p_t) ** gamma)
160
+
161
+ if alpha >= 0:
162
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
163
+ loss = alpha_t * loss
164
+
165
+ if no_reduction:
166
+ return loss
167
+
168
+ return loss.mean(1).sum() / num_boxes
169
+
170
+
171
+ class MLP(nn.Module):
172
+ """Very simple multi-layer perceptron (also called FFN)"""
173
+
174
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
175
+ super().__init__()
176
+ self.num_layers = num_layers
177
+ h = [hidden_dim] * (num_layers - 1)
178
+ self.layers = nn.ModuleList(
179
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
180
+ )
181
+
182
+ def forward(self, x):
183
+ for i, layer in enumerate(self.layers):
184
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
185
+ return x
186
+
187
+
188
+ def _get_activation_fn(activation, d_model=256, batch_dim=0):
189
+ """Return an activation function given a string"""
190
+ if activation == "relu":
191
+ return F.relu
192
+ if activation == "gelu":
193
+ return F.gelu
194
+ if activation == "glu":
195
+ return F.glu
196
+ if activation == "prelu":
197
+ return nn.PReLU()
198
+ if activation == "selu":
199
+ return F.selu
200
+
201
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
202
+
203
+
204
+ def gen_sineembed_for_position(pos_tensor):
205
+ # n_query, bs, _ = pos_tensor.size()
206
+ # sineembed_tensor = torch.zeros(n_query, bs, 256)
207
+ scale = 2 * math.pi
208
+ dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
209
+ dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / 128)
210
+ x_embed = pos_tensor[:, :, 0] * scale
211
+ y_embed = pos_tensor[:, :, 1] * scale
212
+ pos_x = x_embed[:, :, None] / dim_t
213
+ pos_y = y_embed[:, :, None] / dim_t
214
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
215
+ pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
216
+ if pos_tensor.size(-1) == 2:
217
+ pos = torch.cat((pos_y, pos_x), dim=2)
218
+ elif pos_tensor.size(-1) == 4:
219
+ w_embed = pos_tensor[:, :, 2] * scale
220
+ pos_w = w_embed[:, :, None] / dim_t
221
+ pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
222
+
223
+ h_embed = pos_tensor[:, :, 3] * scale
224
+ pos_h = h_embed[:, :, None] / dim_t
225
+ pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
226
+
227
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
228
+ else:
229
+ raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
230
+ return pos
231
+
232
+
233
+ class ContrastiveEmbed(nn.Module):
234
+ def __init__(self, max_text_len=256):
235
+ """
236
+ Args:
237
+ max_text_len: max length of text.
238
+ """
239
+ super().__init__()
240
+ self.max_text_len = max_text_len
241
+
242
+ def forward(self, x, text_dict):
243
+ """_summary_
244
+
245
+ Args:
246
+ x (_type_): _description_
247
+ text_dict (_type_): _description_
248
+ {
249
+ 'encoded_text': encoded_text, # bs, 195, d_model
250
+ 'text_token_mask': text_token_mask, # bs, 195
251
+ # True for used tokens. False for padding tokens
252
+ }
253
+ Returns:
254
+ _type_: _description_
255
+ """
256
+ assert isinstance(text_dict, dict)
257
+
258
+ y = text_dict["encoded_text"]
259
+ text_token_mask = text_dict["text_token_mask"]
260
+
261
+ res = x @ y.transpose(-1, -2)
262
+ res.masked_fill_(~text_token_mask[:, None, :], float("-inf"))
263
+
264
+ # padding to max_text_len
265
+ new_res = torch.full((*res.shape[:-1], self.max_text_len), float("-inf"), device=res.device)
266
+ new_res[..., : res.shape[-1]] = res
267
+
268
+ return new_res
GroundingDINO/groundingdino/models/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ from .GroundingDINO import build_groundingdino
9
+
10
+
11
+ def build_model(args):
12
+ # we use register to maintain models from catdet6 on.
13
+ from .registry import MODULE_BUILD_FUNCS
14
+
15
+ assert args.modelname in MODULE_BUILD_FUNCS._module_dict
16
+ build_func = MODULE_BUILD_FUNCS.get(args.modelname)
17
+ model = build_func(args)
18
+ return model
GroundingDINO/groundingdino/models/registry.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Grounding DINO
3
+ # url: https://github.com/IDEA-Research/GroundingDINO
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # -*- coding: utf-8 -*-
8
+ # @Author: Yihao Chen
9
+ # @Date: 2021-08-16 16:03:17
10
+ # @Last Modified by: Shilong Liu
11
+ # @Last Modified time: 2022-01-23 15:26
12
+ # modified from mmcv
13
+
14
+ import inspect
15
+ from functools import partial
16
+
17
+
18
+ class Registry(object):
19
+ def __init__(self, name):
20
+ self._name = name
21
+ self._module_dict = dict()
22
+
23
+ def __repr__(self):
24
+ format_str = self.__class__.__name__ + "(name={}, items={})".format(
25
+ self._name, list(self._module_dict.keys())
26
+ )
27
+ return format_str
28
+
29
+ def __len__(self):
30
+ return len(self._module_dict)
31
+
32
+ @property
33
+ def name(self):
34
+ return self._name
35
+
36
+ @property
37
+ def module_dict(self):
38
+ return self._module_dict
39
+
40
+ def get(self, key):
41
+ return self._module_dict.get(key, None)
42
+
43
+ def registe_with_name(self, module_name=None, force=False):
44
+ return partial(self.register, module_name=module_name, force=force)
45
+
46
+ def register(self, module_build_function, module_name=None, force=False):
47
+ """Register a module build function.
48
+ Args:
49
+ module (:obj:`nn.Module`): Module to be registered.
50
+ """
51
+ if not inspect.isfunction(module_build_function):
52
+ raise TypeError(
53
+ "module_build_function must be a function, but got {}".format(
54
+ type(module_build_function)
55
+ )
56
+ )
57
+ if module_name is None:
58
+ module_name = module_build_function.__name__
59
+ if not force and module_name in self._module_dict:
60
+ raise KeyError("{} is already registered in {}".format(module_name, self.name))
61
+ self._module_dict[module_name] = module_build_function
62
+
63
+ return module_build_function
64
+
65
+
66
+ MODULE_BUILD_FUNCS = Registry("model build functions")
GroundingDINO/groundingdino/util/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
GroundingDINO/groundingdino/util/box_ops.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Utilities for bounding box manipulation and GIoU.
4
+ """
5
+ import torch
6
+ from torchvision.ops.boxes import box_area
7
+
8
+
9
+ def box_cxcywh_to_xyxy(x):
10
+ x_c, y_c, w, h = x.unbind(-1)
11
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
12
+ return torch.stack(b, dim=-1)
13
+
14
+
15
+ def box_xyxy_to_cxcywh(x):
16
+ x0, y0, x1, y1 = x.unbind(-1)
17
+ b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
18
+ return torch.stack(b, dim=-1)
19
+
20
+
21
+ # modified from torchvision to also return the union
22
+ def box_iou(boxes1, boxes2):
23
+ area1 = box_area(boxes1)
24
+ area2 = box_area(boxes2)
25
+
26
+ # import ipdb; ipdb.set_trace()
27
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
28
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
29
+
30
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
31
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
32
+
33
+ union = area1[:, None] + area2 - inter
34
+
35
+ iou = inter / (union + 1e-6)
36
+ return iou, union
37
+
38
+
39
+ def generalized_box_iou(boxes1, boxes2):
40
+ """
41
+ Generalized IoU from https://giou.stanford.edu/
42
+
43
+ The boxes should be in [x0, y0, x1, y1] format
44
+
45
+ Returns a [N, M] pairwise matrix, where N = len(boxes1)
46
+ and M = len(boxes2)
47
+ """
48
+ # degenerate boxes gives inf / nan results
49
+ # so do an early check
50
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
51
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
52
+ # except:
53
+ # import ipdb; ipdb.set_trace()
54
+ iou, union = box_iou(boxes1, boxes2)
55
+
56
+ lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
57
+ rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
58
+
59
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
60
+ area = wh[:, :, 0] * wh[:, :, 1]
61
+
62
+ return iou - (area - union) / (area + 1e-6)
63
+
64
+
65
+ # modified from torchvision to also return the union
66
+ def box_iou_pairwise(boxes1, boxes2):
67
+ area1 = box_area(boxes1)
68
+ area2 = box_area(boxes2)
69
+
70
+ lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2]
71
+ rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2]
72
+
73
+ wh = (rb - lt).clamp(min=0) # [N,2]
74
+ inter = wh[:, 0] * wh[:, 1] # [N]
75
+
76
+ union = area1 + area2 - inter
77
+
78
+ iou = inter / union
79
+ return iou, union
80
+
81
+
82
+ def generalized_box_iou_pairwise(boxes1, boxes2):
83
+ """
84
+ Generalized IoU from https://giou.stanford.edu/
85
+
86
+ Input:
87
+ - boxes1, boxes2: N,4
88
+ Output:
89
+ - giou: N, 4
90
+ """
91
+ # degenerate boxes gives inf / nan results
92
+ # so do an early check
93
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
94
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
95
+ assert boxes1.shape == boxes2.shape
96
+ iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4
97
+
98
+ lt = torch.min(boxes1[:, :2], boxes2[:, :2])
99
+ rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])
100
+
101
+ wh = (rb - lt).clamp(min=0) # [N,2]
102
+ area = wh[:, 0] * wh[:, 1]
103
+
104
+ return iou - (area - union) / area
105
+
106
+
107
+ def masks_to_boxes(masks):
108
+ """Compute the bounding boxes around the provided masks
109
+
110
+ The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
111
+
112
+ Returns a [N, 4] tensors, with the boxes in xyxy format
113
+ """
114
+ if masks.numel() == 0:
115
+ return torch.zeros((0, 4), device=masks.device)
116
+
117
+ h, w = masks.shape[-2:]
118
+
119
+ y = torch.arange(0, h, dtype=torch.float)
120
+ x = torch.arange(0, w, dtype=torch.float)
121
+ y, x = torch.meshgrid(y, x)
122
+
123
+ x_mask = masks * x.unsqueeze(0)
124
+ x_max = x_mask.flatten(1).max(-1)[0]
125
+ x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
126
+
127
+ y_mask = masks * y.unsqueeze(0)
128
+ y_max = y_mask.flatten(1).max(-1)[0]
129
+ y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
130
+
131
+ return torch.stack([x_min, y_min, x_max, y_max], 1)
132
+
133
+
134
+ if __name__ == "__main__":
135
+ x = torch.rand(5, 4)
136
+ y = torch.rand(3, 4)
137
+ iou, union = box_iou(x, y)
138
+ import ipdb
139
+
140
+ ipdb.set_trace()
GroundingDINO/groundingdino/util/get_tokenlizer.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
2
+
3
+
4
+ def get_tokenlizer(text_encoder_type):
5
+ if not isinstance(text_encoder_type, str):
6
+ # print("text_encoder_type is not a str")
7
+ if hasattr(text_encoder_type, "text_encoder_type"):
8
+ text_encoder_type = text_encoder_type.text_encoder_type
9
+ elif text_encoder_type.get("text_encoder_type", False):
10
+ text_encoder_type = text_encoder_type.get("text_encoder_type")
11
+ else:
12
+ raise ValueError(
13
+ "Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
14
+ )
15
+ print("final text_encoder_type: {}".format(text_encoder_type))
16
+
17
+ tokenizer = AutoTokenizer.from_pretrained(text_encoder_type)
18
+ return tokenizer
19
+
20
+
21
+ def get_pretrained_language_model(text_encoder_type):
22
+ if text_encoder_type == "bert-base-uncased":
23
+ return BertModel.from_pretrained(text_encoder_type)
24
+ if text_encoder_type == "roberta-base":
25
+ return RobertaModel.from_pretrained(text_encoder_type)
26
+ raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))
GroundingDINO/groundingdino/util/inference.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, List
2
+
3
+ import re
4
+ import cv2
5
+ import numpy as np
6
+ import supervision as sv
7
+ import torch
8
+ from PIL import Image
9
+ from torchvision.ops import box_convert
10
+
11
+ import groundingdino.datasets.transforms as T
12
+ from groundingdino.models import build_model
13
+ from groundingdino.util.misc import clean_state_dict
14
+ from groundingdino.util.slconfig import SLConfig
15
+ from groundingdino.util.utils import get_phrases_from_posmap
16
+
17
+ # ----------------------------------------------------------------------------------------------------------------------
18
+ # OLD API
19
+ # ----------------------------------------------------------------------------------------------------------------------
20
+
21
+
22
+ def preprocess_caption(caption: str) -> str:
23
+ result = caption.lower().strip()
24
+ if result.endswith("."):
25
+ return result
26
+ return result + "."
27
+
28
+
29
+ def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
30
+ args = SLConfig.fromfile(model_config_path)
31
+ args.device = device
32
+ model = build_model(args)
33
+ checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
34
+ model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
35
+ model.eval()
36
+ return model
37
+
38
+
39
+ def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
40
+ transform = T.Compose(
41
+ [
42
+ T.RandomResize([800], max_size=1333),
43
+ T.ToTensor(),
44
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
45
+ ]
46
+ )
47
+ image_source = Image.open(image_path).convert("RGB")
48
+ image = np.asarray(image_source)
49
+ image_transformed, _ = transform(image_source, None)
50
+ return image, image_transformed
51
+
52
+
53
+ def predict(
54
+ model,
55
+ image: torch.Tensor,
56
+ caption: str,
57
+ box_threshold: float,
58
+ text_threshold: float,
59
+ device: str = "cuda"
60
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
61
+ caption = preprocess_caption(caption=caption)
62
+
63
+ model = model.to(device)
64
+ image = image.to(device)
65
+
66
+ with torch.no_grad():
67
+ outputs = model(image[None], captions=[caption])
68
+
69
+ prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256)
70
+ prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4)
71
+
72
+ mask = prediction_logits.max(dim=1)[0] > box_threshold
73
+ logits = prediction_logits[mask] # logits.shape = (n, 256)
74
+ boxes = prediction_boxes[mask] # boxes.shape = (n, 4)
75
+
76
+ tokenizer = model.tokenizer
77
+ tokenized = tokenizer(caption)
78
+
79
+ phrases = [
80
+ get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
81
+ for logit
82
+ in logits
83
+ ]
84
+
85
+ return boxes, logits.max(dim=1)[0], phrases
86
+
87
+
88
+ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
89
+ h, w, _ = image_source.shape
90
+ boxes = boxes * torch.Tensor([w, h, w, h])
91
+ xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
92
+ detections = sv.Detections(xyxy=xyxy)
93
+
94
+ labels = [
95
+ f"{phrase} {logit:.2f}"
96
+ for phrase, logit
97
+ in zip(phrases, logits)
98
+ ]
99
+
100
+ box_annotator = sv.BoxAnnotator()
101
+ annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
102
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
103
+ return annotated_frame
104
+
105
+
106
+ # ----------------------------------------------------------------------------------------------------------------------
107
+ # NEW API
108
+ # ----------------------------------------------------------------------------------------------------------------------
109
+
110
+
111
+ class Model:
112
+
113
+ def __init__(
114
+ self,
115
+ model_config_path: str,
116
+ model_checkpoint_path: str,
117
+ device: str = "cuda"
118
+ ):
119
+ self.model = load_model(
120
+ model_config_path=model_config_path,
121
+ model_checkpoint_path=model_checkpoint_path,
122
+ device=device
123
+ ).to(device)
124
+ self.device = device
125
+
126
+ def predict_with_caption(
127
+ self,
128
+ image: np.ndarray,
129
+ caption: str,
130
+ box_threshold: float = 0.35,
131
+ text_threshold: float = 0.25
132
+ ) -> Tuple[sv.Detections, List[str]]:
133
+ """
134
+ import cv2
135
+
136
+ image = cv2.imread(IMAGE_PATH)
137
+
138
+ model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
139
+ detections, labels = model.predict_with_caption(
140
+ image=image,
141
+ caption=caption,
142
+ box_threshold=BOX_THRESHOLD,
143
+ text_threshold=TEXT_THRESHOLD
144
+ )
145
+
146
+ import supervision as sv
147
+
148
+ box_annotator = sv.BoxAnnotator()
149
+ annotated_image = box_annotator.annotate(scene=image, detections=detections, labels=labels)
150
+ """
151
+ processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
152
+ boxes, logits, phrases = predict(
153
+ model=self.model,
154
+ image=processed_image,
155
+ caption=caption,
156
+ box_threshold=box_threshold,
157
+ text_threshold=text_threshold,
158
+ device=self.device)
159
+ source_h, source_w, _ = image.shape
160
+ detections = Model.post_process_result(
161
+ source_h=source_h,
162
+ source_w=source_w,
163
+ boxes=boxes,
164
+ logits=logits)
165
+ return detections, phrases
166
+
167
+ def predict_with_classes(
168
+ self,
169
+ image: np.ndarray,
170
+ classes: List[str],
171
+ box_threshold: float,
172
+ text_threshold: float
173
+ ) -> sv.Detections:
174
+ """
175
+ import cv2
176
+
177
+ image = cv2.imread(IMAGE_PATH)
178
+
179
+ model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
180
+ detections = model.predict_with_classes(
181
+ image=image,
182
+ classes=CLASSES,
183
+ box_threshold=BOX_THRESHOLD,
184
+ text_threshold=TEXT_THRESHOLD
185
+ )
186
+
187
+
188
+ import supervision as sv
189
+
190
+ box_annotator = sv.BoxAnnotator()
191
+ annotated_image = box_annotator.annotate(scene=image, detections=detections)
192
+ """
193
+ caption = ". ".join(classes)
194
+ processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
195
+ boxes, logits, phrases = predict(
196
+ model=self.model,
197
+ image=processed_image,
198
+ caption=caption,
199
+ box_threshold=box_threshold,
200
+ text_threshold=text_threshold,
201
+ device=self.device)
202
+ source_h, source_w, _ = image.shape
203
+ detections = Model.post_process_result(
204
+ source_h=source_h,
205
+ source_w=source_w,
206
+ boxes=boxes,
207
+ logits=logits)
208
+ class_id = Model.phrases2classes(phrases=phrases, classes=classes)
209
+ detections.class_id = class_id
210
+ return detections
211
+
212
+ @staticmethod
213
+ def preprocess_image(image_bgr: np.ndarray) -> torch.Tensor:
214
+ transform = T.Compose(
215
+ [
216
+ T.RandomResize([800], max_size=1333),
217
+ T.ToTensor(),
218
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
219
+ ]
220
+ )
221
+ image_pillow = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
222
+ image_transformed, _ = transform(image_pillow, None)
223
+ return image_transformed
224
+
225
+ @staticmethod
226
+ def post_process_result(
227
+ source_h: int,
228
+ source_w: int,
229
+ boxes: torch.Tensor,
230
+ logits: torch.Tensor
231
+ ) -> sv.Detections:
232
+ boxes = boxes * torch.Tensor([source_w, source_h, source_w, source_h])
233
+ xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
234
+ confidence = logits.numpy()
235
+ return sv.Detections(xyxy=xyxy, confidence=confidence)
236
+
237
+ @staticmethod
238
+ def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray:
239
+ class_ids = []
240
+ for phrase in phrases:
241
+ try:
242
+ # class_ids.append(classes.index(phrase))
243
+ class_ids.append(Model.find_index(phrase, classes))
244
+ except ValueError:
245
+ class_ids.append(None)
246
+ return np.array(class_ids)
247
+
248
+ @staticmethod
249
+ def find_index(string, lst):
250
+ # if meet string like "lake river" will only keep "lake"
251
+ # this is an hack implementation for visualization which will be updated in the future
252
+ string = string.lower().split()[0]
253
+ for i, s in enumerate(lst):
254
+ if string in s.lower():
255
+ return i
256
+ print("There's a wrong phrase happen, this is because of our post-process merged wrong tokens, which will be modified in the future. We will assign it with a random label at this time.")
257
+ return 0
GroundingDINO/groundingdino/util/logger.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import functools
3
+ import logging
4
+ import os
5
+ import sys
6
+
7
+ from termcolor import colored
8
+
9
+
10
+ class _ColorfulFormatter(logging.Formatter):
11
+ def __init__(self, *args, **kwargs):
12
+ self._root_name = kwargs.pop("root_name") + "."
13
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
14
+ if len(self._abbrev_name):
15
+ self._abbrev_name = self._abbrev_name + "."
16
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
17
+
18
+ def formatMessage(self, record):
19
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
20
+ log = super(_ColorfulFormatter, self).formatMessage(record)
21
+ if record.levelno == logging.WARNING:
22
+ prefix = colored("WARNING", "red", attrs=["blink"])
23
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
24
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
25
+ else:
26
+ return log
27
+ return prefix + " " + log
28
+
29
+
30
+ # so that calling setup_logger multiple times won't add many handlers
31
+ @functools.lru_cache()
32
+ def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None):
33
+ """
34
+ Initialize the detectron2 logger and set its verbosity level to "INFO".
35
+
36
+ Args:
37
+ output (str): a file name or a directory to save log. If None, will not save log file.
38
+ If ends with ".txt" or ".log", assumed to be a file name.
39
+ Otherwise, logs will be saved to `output/log.txt`.
40
+ name (str): the root module name of this logger
41
+
42
+ Returns:
43
+ logging.Logger: a logger
44
+ """
45
+ logger = logging.getLogger(name)
46
+ logger.setLevel(logging.DEBUG)
47
+ logger.propagate = False
48
+
49
+ if abbrev_name is None:
50
+ abbrev_name = name
51
+
52
+ plain_formatter = logging.Formatter(
53
+ "[%(asctime)s.%(msecs)03d]: %(message)s", datefmt="%m/%d %H:%M:%S"
54
+ )
55
+ # stdout logging: master only
56
+ if distributed_rank == 0:
57
+ ch = logging.StreamHandler(stream=sys.stdout)
58
+ ch.setLevel(logging.DEBUG)
59
+ if color:
60
+ formatter = _ColorfulFormatter(
61
+ colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s",
62
+ datefmt="%m/%d %H:%M:%S",
63
+ root_name=name,
64
+ abbrev_name=str(abbrev_name),
65
+ )
66
+ else:
67
+ formatter = plain_formatter
68
+ ch.setFormatter(formatter)
69
+ logger.addHandler(ch)
70
+
71
+ # file logging: all workers
72
+ if output is not None:
73
+ if output.endswith(".txt") or output.endswith(".log"):
74
+ filename = output
75
+ else:
76
+ filename = os.path.join(output, "log.txt")
77
+ if distributed_rank > 0:
78
+ filename = filename + f".rank{distributed_rank}"
79
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
80
+
81
+ fh = logging.StreamHandler(_cached_log_stream(filename))
82
+ fh.setLevel(logging.DEBUG)
83
+ fh.setFormatter(plain_formatter)
84
+ logger.addHandler(fh)
85
+
86
+ return logger
87
+
88
+
89
+ # cache the opened file object, so that different calls to `setup_logger`
90
+ # with the same file name can safely write to the same file.
91
+ @functools.lru_cache(maxsize=None)
92
+ def _cached_log_stream(filename):
93
+ return open(filename, "a")
GroundingDINO/groundingdino/util/misc.py ADDED
@@ -0,0 +1,717 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Misc functions, including distributed helpers.
4
+
5
+ Mostly copy-paste from torchvision references.
6
+ """
7
+ import colorsys
8
+ import datetime
9
+ import functools
10
+ import io
11
+ import json
12
+ import os
13
+ import pickle
14
+ import subprocess
15
+ import time
16
+ from collections import OrderedDict, defaultdict, deque
17
+ from typing import List, Optional
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.distributed as dist
22
+
23
+ # needed due to empty tensor bug in pytorch and torchvision 0.5
24
+ import torchvision
25
+ from torch import Tensor
26
+
27
+ __torchvision_need_compat_flag = float(torchvision.__version__.split(".")[1]) < 7
28
+ if __torchvision_need_compat_flag:
29
+ from torchvision.ops import _new_empty_tensor
30
+ from torchvision.ops.misc import _output_size
31
+
32
+
33
+ class SmoothedValue(object):
34
+ """Track a series of values and provide access to smoothed values over a
35
+ window or the global series average.
36
+ """
37
+
38
+ def __init__(self, window_size=20, fmt=None):
39
+ if fmt is None:
40
+ fmt = "{median:.4f} ({global_avg:.4f})"
41
+ self.deque = deque(maxlen=window_size)
42
+ self.total = 0.0
43
+ self.count = 0
44
+ self.fmt = fmt
45
+
46
+ def update(self, value, n=1):
47
+ self.deque.append(value)
48
+ self.count += n
49
+ self.total += value * n
50
+
51
+ def synchronize_between_processes(self):
52
+ """
53
+ Warning: does not synchronize the deque!
54
+ """
55
+ if not is_dist_avail_and_initialized():
56
+ return
57
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
58
+ dist.barrier()
59
+ dist.all_reduce(t)
60
+ t = t.tolist()
61
+ self.count = int(t[0])
62
+ self.total = t[1]
63
+
64
+ @property
65
+ def median(self):
66
+ d = torch.tensor(list(self.deque))
67
+ if d.shape[0] == 0:
68
+ return 0
69
+ return d.median().item()
70
+
71
+ @property
72
+ def avg(self):
73
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
74
+ return d.mean().item()
75
+
76
+ @property
77
+ def global_avg(self):
78
+ if os.environ.get("SHILONG_AMP", None) == "1":
79
+ eps = 1e-4
80
+ else:
81
+ eps = 1e-6
82
+ return self.total / (self.count + eps)
83
+
84
+ @property
85
+ def max(self):
86
+ return max(self.deque)
87
+
88
+ @property
89
+ def value(self):
90
+ return self.deque[-1]
91
+
92
+ def __str__(self):
93
+ return self.fmt.format(
94
+ median=self.median,
95
+ avg=self.avg,
96
+ global_avg=self.global_avg,
97
+ max=self.max,
98
+ value=self.value,
99
+ )
100
+
101
+
102
+ @functools.lru_cache()
103
+ def _get_global_gloo_group():
104
+ """
105
+ Return a process group based on gloo backend, containing all the ranks
106
+ The result is cached.
107
+ """
108
+
109
+ if dist.get_backend() == "nccl":
110
+ return dist.new_group(backend="gloo")
111
+
112
+ return dist.group.WORLD
113
+
114
+
115
+ def all_gather_cpu(data):
116
+ """
117
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
118
+ Args:
119
+ data: any picklable object
120
+ Returns:
121
+ list[data]: list of data gathered from each rank
122
+ """
123
+
124
+ world_size = get_world_size()
125
+ if world_size == 1:
126
+ return [data]
127
+
128
+ cpu_group = _get_global_gloo_group()
129
+
130
+ buffer = io.BytesIO()
131
+ torch.save(data, buffer)
132
+ data_view = buffer.getbuffer()
133
+ device = "cuda" if cpu_group is None else "cpu"
134
+ tensor = torch.ByteTensor(data_view).to(device)
135
+
136
+ # obtain Tensor size of each rank
137
+ local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
138
+ size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)]
139
+ if cpu_group is None:
140
+ dist.all_gather(size_list, local_size)
141
+ else:
142
+ print("gathering on cpu")
143
+ dist.all_gather(size_list, local_size, group=cpu_group)
144
+ size_list = [int(size.item()) for size in size_list]
145
+ max_size = max(size_list)
146
+ assert isinstance(local_size.item(), int)
147
+ local_size = int(local_size.item())
148
+
149
+ # receiving Tensor from all ranks
150
+ # we pad the tensor because torch all_gather does not support
151
+ # gathering tensors of different shapes
152
+ tensor_list = []
153
+ for _ in size_list:
154
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
155
+ if local_size != max_size:
156
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device)
157
+ tensor = torch.cat((tensor, padding), dim=0)
158
+ if cpu_group is None:
159
+ dist.all_gather(tensor_list, tensor)
160
+ else:
161
+ dist.all_gather(tensor_list, tensor, group=cpu_group)
162
+
163
+ data_list = []
164
+ for size, tensor in zip(size_list, tensor_list):
165
+ tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
166
+ buffer = io.BytesIO(tensor.cpu().numpy())
167
+ obj = torch.load(buffer)
168
+ data_list.append(obj)
169
+
170
+ return data_list
171
+
172
+
173
+ def all_gather(data):
174
+ """
175
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
176
+ Args:
177
+ data: any picklable object
178
+ Returns:
179
+ list[data]: list of data gathered from each rank
180
+ """
181
+
182
+ if os.getenv("CPU_REDUCE") == "1":
183
+ return all_gather_cpu(data)
184
+
185
+ world_size = get_world_size()
186
+ if world_size == 1:
187
+ return [data]
188
+
189
+ # serialized to a Tensor
190
+ buffer = pickle.dumps(data)
191
+ storage = torch.ByteStorage.from_buffer(buffer)
192
+ tensor = torch.ByteTensor(storage).to("cuda")
193
+
194
+ # obtain Tensor size of each rank
195
+ local_size = torch.tensor([tensor.numel()], device="cuda")
196
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
197
+ dist.all_gather(size_list, local_size)
198
+ size_list = [int(size.item()) for size in size_list]
199
+ max_size = max(size_list)
200
+
201
+ # receiving Tensor from all ranks
202
+ # we pad the tensor because torch all_gather does not support
203
+ # gathering tensors of different shapes
204
+ tensor_list = []
205
+ for _ in size_list:
206
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
207
+ if local_size != max_size:
208
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
209
+ tensor = torch.cat((tensor, padding), dim=0)
210
+ dist.all_gather(tensor_list, tensor)
211
+
212
+ data_list = []
213
+ for size, tensor in zip(size_list, tensor_list):
214
+ buffer = tensor.cpu().numpy().tobytes()[:size]
215
+ data_list.append(pickle.loads(buffer))
216
+
217
+ return data_list
218
+
219
+
220
+ def reduce_dict(input_dict, average=True):
221
+ """
222
+ Args:
223
+ input_dict (dict): all the values will be reduced
224
+ average (bool): whether to do average or sum
225
+ Reduce the values in the dictionary from all processes so that all processes
226
+ have the averaged results. Returns a dict with the same fields as
227
+ input_dict, after reduction.
228
+ """
229
+ world_size = get_world_size()
230
+ if world_size < 2:
231
+ return input_dict
232
+ with torch.no_grad():
233
+ names = []
234
+ values = []
235
+ # sort the keys so that they are consistent across processes
236
+ for k in sorted(input_dict.keys()):
237
+ names.append(k)
238
+ values.append(input_dict[k])
239
+ values = torch.stack(values, dim=0)
240
+ dist.all_reduce(values)
241
+ if average:
242
+ values /= world_size
243
+ reduced_dict = {k: v for k, v in zip(names, values)}
244
+ return reduced_dict
245
+
246
+
247
+ class MetricLogger(object):
248
+ def __init__(self, delimiter="\t"):
249
+ self.meters = defaultdict(SmoothedValue)
250
+ self.delimiter = delimiter
251
+
252
+ def update(self, **kwargs):
253
+ for k, v in kwargs.items():
254
+ if isinstance(v, torch.Tensor):
255
+ v = v.item()
256
+ assert isinstance(v, (float, int))
257
+ self.meters[k].update(v)
258
+
259
+ def __getattr__(self, attr):
260
+ if attr in self.meters:
261
+ return self.meters[attr]
262
+ if attr in self.__dict__:
263
+ return self.__dict__[attr]
264
+ raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
265
+
266
+ def __str__(self):
267
+ loss_str = []
268
+ for name, meter in self.meters.items():
269
+ # print(name, str(meter))
270
+ # import ipdb;ipdb.set_trace()
271
+ if meter.count > 0:
272
+ loss_str.append("{}: {}".format(name, str(meter)))
273
+ return self.delimiter.join(loss_str)
274
+
275
+ def synchronize_between_processes(self):
276
+ for meter in self.meters.values():
277
+ meter.synchronize_between_processes()
278
+
279
+ def add_meter(self, name, meter):
280
+ self.meters[name] = meter
281
+
282
+ def log_every(self, iterable, print_freq, header=None, logger=None):
283
+ if logger is None:
284
+ print_func = print
285
+ else:
286
+ print_func = logger.info
287
+
288
+ i = 0
289
+ if not header:
290
+ header = ""
291
+ start_time = time.time()
292
+ end = time.time()
293
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
294
+ data_time = SmoothedValue(fmt="{avg:.4f}")
295
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
296
+ if torch.cuda.is_available():
297
+ log_msg = self.delimiter.join(
298
+ [
299
+ header,
300
+ "[{0" + space_fmt + "}/{1}]",
301
+ "eta: {eta}",
302
+ "{meters}",
303
+ "time: {time}",
304
+ "data: {data}",
305
+ "max mem: {memory:.0f}",
306
+ ]
307
+ )
308
+ else:
309
+ log_msg = self.delimiter.join(
310
+ [
311
+ header,
312
+ "[{0" + space_fmt + "}/{1}]",
313
+ "eta: {eta}",
314
+ "{meters}",
315
+ "time: {time}",
316
+ "data: {data}",
317
+ ]
318
+ )
319
+ MB = 1024.0 * 1024.0
320
+ for obj in iterable:
321
+ data_time.update(time.time() - end)
322
+ yield obj
323
+ # import ipdb; ipdb.set_trace()
324
+ iter_time.update(time.time() - end)
325
+ if i % print_freq == 0 or i == len(iterable) - 1:
326
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
327
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
328
+ if torch.cuda.is_available():
329
+ print_func(
330
+ log_msg.format(
331
+ i,
332
+ len(iterable),
333
+ eta=eta_string,
334
+ meters=str(self),
335
+ time=str(iter_time),
336
+ data=str(data_time),
337
+ memory=torch.cuda.max_memory_allocated() / MB,
338
+ )
339
+ )
340
+ else:
341
+ print_func(
342
+ log_msg.format(
343
+ i,
344
+ len(iterable),
345
+ eta=eta_string,
346
+ meters=str(self),
347
+ time=str(iter_time),
348
+ data=str(data_time),
349
+ )
350
+ )
351
+ i += 1
352
+ end = time.time()
353
+ total_time = time.time() - start_time
354
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
355
+ print_func(
356
+ "{} Total time: {} ({:.4f} s / it)".format(
357
+ header, total_time_str, total_time / len(iterable)
358
+ )
359
+ )
360
+
361
+
362
+ def get_sha():
363
+ cwd = os.path.dirname(os.path.abspath(__file__))
364
+
365
+ def _run(command):
366
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
367
+
368
+ sha = "N/A"
369
+ diff = "clean"
370
+ branch = "N/A"
371
+ try:
372
+ sha = _run(["git", "rev-parse", "HEAD"])
373
+ subprocess.check_output(["git", "diff"], cwd=cwd)
374
+ diff = _run(["git", "diff-index", "HEAD"])
375
+ diff = "has uncommited changes" if diff else "clean"
376
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
377
+ except Exception:
378
+ pass
379
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
380
+ return message
381
+
382
+
383
+ def collate_fn(batch):
384
+ # import ipdb; ipdb.set_trace()
385
+ batch = list(zip(*batch))
386
+ batch[0] = nested_tensor_from_tensor_list(batch[0])
387
+ return tuple(batch)
388
+
389
+
390
+ def _max_by_axis(the_list):
391
+ # type: (List[List[int]]) -> List[int]
392
+ maxes = the_list[0]
393
+ for sublist in the_list[1:]:
394
+ for index, item in enumerate(sublist):
395
+ maxes[index] = max(maxes[index], item)
396
+ return maxes
397
+
398
+
399
+ class NestedTensor(object):
400
+ def __init__(self, tensors, mask: Optional[Tensor]):
401
+ self.tensors = tensors
402
+ self.mask = mask
403
+ if mask == "auto":
404
+ self.mask = torch.zeros_like(tensors).to(tensors.device)
405
+ if self.mask.dim() == 3:
406
+ self.mask = self.mask.sum(0).to(bool)
407
+ elif self.mask.dim() == 4:
408
+ self.mask = self.mask.sum(1).to(bool)
409
+ else:
410
+ raise ValueError(
411
+ "tensors dim must be 3 or 4 but {}({})".format(
412
+ self.tensors.dim(), self.tensors.shape
413
+ )
414
+ )
415
+
416
+ def imgsize(self):
417
+ res = []
418
+ for i in range(self.tensors.shape[0]):
419
+ mask = self.mask[i]
420
+ maxH = (~mask).sum(0).max()
421
+ maxW = (~mask).sum(1).max()
422
+ res.append(torch.Tensor([maxH, maxW]))
423
+ return res
424
+
425
+ def to(self, device):
426
+ # type: (Device) -> NestedTensor # noqa
427
+ cast_tensor = self.tensors.to(device)
428
+ mask = self.mask
429
+ if mask is not None:
430
+ assert mask is not None
431
+ cast_mask = mask.to(device)
432
+ else:
433
+ cast_mask = None
434
+ return NestedTensor(cast_tensor, cast_mask)
435
+
436
+ def to_img_list_single(self, tensor, mask):
437
+ assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim())
438
+ maxH = (~mask).sum(0).max()
439
+ maxW = (~mask).sum(1).max()
440
+ img = tensor[:, :maxH, :maxW]
441
+ return img
442
+
443
+ def to_img_list(self):
444
+ """remove the padding and convert to img list
445
+
446
+ Returns:
447
+ [type]: [description]
448
+ """
449
+ if self.tensors.dim() == 3:
450
+ return self.to_img_list_single(self.tensors, self.mask)
451
+ else:
452
+ res = []
453
+ for i in range(self.tensors.shape[0]):
454
+ tensor_i = self.tensors[i]
455
+ mask_i = self.mask[i]
456
+ res.append(self.to_img_list_single(tensor_i, mask_i))
457
+ return res
458
+
459
+ @property
460
+ def device(self):
461
+ return self.tensors.device
462
+
463
+ def decompose(self):
464
+ return self.tensors, self.mask
465
+
466
+ def __repr__(self):
467
+ return str(self.tensors)
468
+
469
+ @property
470
+ def shape(self):
471
+ return {"tensors.shape": self.tensors.shape, "mask.shape": self.mask.shape}
472
+
473
+
474
+ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
475
+ # TODO make this more general
476
+ if tensor_list[0].ndim == 3:
477
+ if torchvision._is_tracing():
478
+ # nested_tensor_from_tensor_list() does not export well to ONNX
479
+ # call _onnx_nested_tensor_from_tensor_list() instead
480
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
481
+
482
+ # TODO make it support different-sized images
483
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
484
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
485
+ batch_shape = [len(tensor_list)] + max_size
486
+ b, c, h, w = batch_shape
487
+ dtype = tensor_list[0].dtype
488
+ device = tensor_list[0].device
489
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
490
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
491
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
492
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
493
+ m[: img.shape[1], : img.shape[2]] = False
494
+ else:
495
+ raise ValueError("not supported")
496
+ return NestedTensor(tensor, mask)
497
+
498
+
499
+ # _onnx_nested_tensor_from_tensor_list() is an implementation of
500
+ # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
501
+ @torch.jit.unused
502
+ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
503
+ max_size = []
504
+ for i in range(tensor_list[0].dim()):
505
+ max_size_i = torch.max(
506
+ torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
507
+ ).to(torch.int64)
508
+ max_size.append(max_size_i)
509
+ max_size = tuple(max_size)
510
+
511
+ # work around for
512
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
513
+ # m[: img.shape[1], :img.shape[2]] = False
514
+ # which is not yet supported in onnx
515
+ padded_imgs = []
516
+ padded_masks = []
517
+ for img in tensor_list:
518
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
519
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
520
+ padded_imgs.append(padded_img)
521
+
522
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
523
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
524
+ padded_masks.append(padded_mask.to(torch.bool))
525
+
526
+ tensor = torch.stack(padded_imgs)
527
+ mask = torch.stack(padded_masks)
528
+
529
+ return NestedTensor(tensor, mask=mask)
530
+
531
+
532
+ def setup_for_distributed(is_master):
533
+ """
534
+ This function disables printing when not in master process
535
+ """
536
+ import builtins as __builtin__
537
+
538
+ builtin_print = __builtin__.print
539
+
540
+ def print(*args, **kwargs):
541
+ force = kwargs.pop("force", False)
542
+ if is_master or force:
543
+ builtin_print(*args, **kwargs)
544
+
545
+ __builtin__.print = print
546
+
547
+
548
+ def is_dist_avail_and_initialized():
549
+ if not dist.is_available():
550
+ return False
551
+ if not dist.is_initialized():
552
+ return False
553
+ return True
554
+
555
+
556
+ def get_world_size():
557
+ if not is_dist_avail_and_initialized():
558
+ return 1
559
+ return dist.get_world_size()
560
+
561
+
562
+ def get_rank():
563
+ if not is_dist_avail_and_initialized():
564
+ return 0
565
+ return dist.get_rank()
566
+
567
+
568
+ def is_main_process():
569
+ return get_rank() == 0
570
+
571
+
572
+ def save_on_master(*args, **kwargs):
573
+ if is_main_process():
574
+ torch.save(*args, **kwargs)
575
+
576
+
577
+ def init_distributed_mode(args):
578
+ if "WORLD_SIZE" in os.environ and os.environ["WORLD_SIZE"] != "": # 'RANK' in os.environ and
579
+ args.rank = int(os.environ["RANK"])
580
+ args.world_size = int(os.environ["WORLD_SIZE"])
581
+ args.gpu = args.local_rank = int(os.environ["LOCAL_RANK"])
582
+
583
+ # launch by torch.distributed.launch
584
+ # Single node
585
+ # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ...
586
+ # Multi nodes
587
+ # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
588
+ # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
589
+ # args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK'))
590
+ # local_world_size = int(os.environ['GPU_PER_NODE_COUNT'])
591
+ # args.world_size = args.world_size * local_world_size
592
+ # args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])
593
+ # args.rank = args.rank * local_world_size + args.local_rank
594
+ print(
595
+ "world size: {}, rank: {}, local rank: {}".format(
596
+ args.world_size, args.rank, args.local_rank
597
+ )
598
+ )
599
+ print(json.dumps(dict(os.environ), indent=2))
600
+ elif "SLURM_PROCID" in os.environ:
601
+ args.rank = int(os.environ["SLURM_PROCID"])
602
+ args.gpu = args.local_rank = int(os.environ["SLURM_LOCALID"])
603
+ args.world_size = int(os.environ["SLURM_NPROCS"])
604
+
605
+ print(
606
+ "world size: {}, world rank: {}, local rank: {}, device_count: {}".format(
607
+ args.world_size, args.rank, args.local_rank, torch.cuda.device_count()
608
+ )
609
+ )
610
+ else:
611
+ print("Not using distributed mode")
612
+ args.distributed = False
613
+ args.world_size = 1
614
+ args.rank = 0
615
+ args.local_rank = 0
616
+ return
617
+
618
+ print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank))
619
+ args.distributed = True
620
+ torch.cuda.set_device(args.local_rank)
621
+ args.dist_backend = "nccl"
622
+ print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
623
+
624
+ torch.distributed.init_process_group(
625
+ backend=args.dist_backend,
626
+ world_size=args.world_size,
627
+ rank=args.rank,
628
+ init_method=args.dist_url,
629
+ )
630
+
631
+ print("Before torch.distributed.barrier()")
632
+ torch.distributed.barrier()
633
+ print("End torch.distributed.barrier()")
634
+ setup_for_distributed(args.rank == 0)
635
+
636
+
637
+ @torch.no_grad()
638
+ def accuracy(output, target, topk=(1,)):
639
+ """Computes the precision@k for the specified values of k"""
640
+ if target.numel() == 0:
641
+ return [torch.zeros([], device=output.device)]
642
+ maxk = max(topk)
643
+ batch_size = target.size(0)
644
+
645
+ _, pred = output.topk(maxk, 1, True, True)
646
+ pred = pred.t()
647
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
648
+
649
+ res = []
650
+ for k in topk:
651
+ correct_k = correct[:k].view(-1).float().sum(0)
652
+ res.append(correct_k.mul_(100.0 / batch_size))
653
+ return res
654
+
655
+
656
+ @torch.no_grad()
657
+ def accuracy_onehot(pred, gt):
658
+ """_summary_
659
+
660
+ Args:
661
+ pred (_type_): n, c
662
+ gt (_type_): n, c
663
+ """
664
+ tp = ((pred - gt).abs().sum(-1) < 1e-4).float().sum()
665
+ acc = tp / gt.shape[0] * 100
666
+ return acc
667
+
668
+
669
+ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
670
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
671
+ """
672
+ Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
673
+ This will eventually be supported natively by PyTorch, and this
674
+ class can go away.
675
+ """
676
+ if __torchvision_need_compat_flag < 0.7:
677
+ if input.numel() > 0:
678
+ return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
679
+
680
+ output_shape = _output_size(2, input, size, scale_factor)
681
+ output_shape = list(input.shape[:-2]) + list(output_shape)
682
+ return _new_empty_tensor(input, output_shape)
683
+ else:
684
+ return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
685
+
686
+
687
+ class color_sys:
688
+ def __init__(self, num_colors) -> None:
689
+ self.num_colors = num_colors
690
+ colors = []
691
+ for i in np.arange(0.0, 360.0, 360.0 / num_colors):
692
+ hue = i / 360.0
693
+ lightness = (50 + np.random.rand() * 10) / 100.0
694
+ saturation = (90 + np.random.rand() * 10) / 100.0
695
+ colors.append(
696
+ tuple([int(j * 255) for j in colorsys.hls_to_rgb(hue, lightness, saturation)])
697
+ )
698
+ self.colors = colors
699
+
700
+ def __call__(self, idx):
701
+ return self.colors[idx]
702
+
703
+
704
+ def inverse_sigmoid(x, eps=1e-3):
705
+ x = x.clamp(min=0, max=1)
706
+ x1 = x.clamp(min=eps)
707
+ x2 = (1 - x).clamp(min=eps)
708
+ return torch.log(x1 / x2)
709
+
710
+
711
+ def clean_state_dict(state_dict):
712
+ new_state_dict = OrderedDict()
713
+ for k, v in state_dict.items():
714
+ if k[:7] == "module.":
715
+ k = k[7:] # remove `module.`
716
+ new_state_dict[k] = v
717
+ return new_state_dict
GroundingDINO/groundingdino/util/slconfig.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==========================================================
2
+ # Modified from mmcv
3
+ # ==========================================================
4
+ import ast
5
+ import os.path as osp
6
+ import shutil
7
+ import sys
8
+ import tempfile
9
+ from argparse import Action
10
+ from importlib import import_module
11
+ import platform
12
+
13
+ from addict import Dict
14
+ from yapf.yapflib.yapf_api import FormatCode
15
+
16
+ BASE_KEY = "_base_"
17
+ DELETE_KEY = "_delete_"
18
+ RESERVED_KEYS = ["filename", "text", "pretty_text", "get", "dump", "merge_from_dict"]
19
+
20
+
21
+ def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
22
+ if not osp.isfile(filename):
23
+ raise FileNotFoundError(msg_tmpl.format(filename))
24
+
25
+
26
+ class ConfigDict(Dict):
27
+ def __missing__(self, name):
28
+ raise KeyError(name)
29
+
30
+ def __getattr__(self, name):
31
+ try:
32
+ value = super(ConfigDict, self).__getattr__(name)
33
+ except KeyError:
34
+ ex = AttributeError(f"'{self.__class__.__name__}' object has no " f"attribute '{name}'")
35
+ except Exception as e:
36
+ ex = e
37
+ else:
38
+ return value
39
+ raise ex
40
+
41
+
42
+ class SLConfig(object):
43
+ """
44
+ config files.
45
+ only support .py file as config now.
46
+
47
+ ref: mmcv.utils.config
48
+
49
+ Example:
50
+ >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
51
+ >>> cfg.a
52
+ 1
53
+ >>> cfg.b
54
+ {'b1': [0, 1]}
55
+ >>> cfg.b.b1
56
+ [0, 1]
57
+ >>> cfg = Config.fromfile('tests/data/config/a.py')
58
+ >>> cfg.filename
59
+ "/home/kchen/projects/mmcv/tests/data/config/a.py"
60
+ >>> cfg.item4
61
+ 'test'
62
+ >>> cfg
63
+ "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
64
+ "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
65
+ """
66
+
67
+ @staticmethod
68
+ def _validate_py_syntax(filename):
69
+ with open(filename) as f:
70
+ content = f.read()
71
+ try:
72
+ ast.parse(content)
73
+ except SyntaxError:
74
+ raise SyntaxError("There are syntax errors in config " f"file {filename}")
75
+
76
+ @staticmethod
77
+ def _file2dict(filename):
78
+ filename = osp.abspath(osp.expanduser(filename))
79
+ check_file_exist(filename)
80
+ if filename.lower().endswith(".py"):
81
+ with tempfile.TemporaryDirectory() as temp_config_dir:
82
+ temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=".py")
83
+ temp_config_name = osp.basename(temp_config_file.name)
84
+ if platform.system() == 'Windows':
85
+ temp_config_file.close()
86
+ shutil.copyfile(filename, osp.join(temp_config_dir, temp_config_name))
87
+ temp_module_name = osp.splitext(temp_config_name)[0]
88
+ sys.path.insert(0, temp_config_dir)
89
+ SLConfig._validate_py_syntax(filename)
90
+ mod = import_module(temp_module_name)
91
+ sys.path.pop(0)
92
+ cfg_dict = {
93
+ name: value for name, value in mod.__dict__.items() if not name.startswith("__")
94
+ }
95
+ # delete imported module
96
+ del sys.modules[temp_module_name]
97
+ # close temp file
98
+ temp_config_file.close()
99
+ elif filename.lower().endswith((".yml", ".yaml", ".json")):
100
+ from .slio import slload
101
+
102
+ cfg_dict = slload(filename)
103
+ else:
104
+ raise IOError("Only py/yml/yaml/json type are supported now!")
105
+
106
+ cfg_text = filename + "\n"
107
+ with open(filename, "r") as f:
108
+ cfg_text += f.read()
109
+
110
+ # parse the base file
111
+ if BASE_KEY in cfg_dict:
112
+ cfg_dir = osp.dirname(filename)
113
+ base_filename = cfg_dict.pop(BASE_KEY)
114
+ base_filename = base_filename if isinstance(base_filename, list) else [base_filename]
115
+
116
+ cfg_dict_list = list()
117
+ cfg_text_list = list()
118
+ for f in base_filename:
119
+ _cfg_dict, _cfg_text = SLConfig._file2dict(osp.join(cfg_dir, f))
120
+ cfg_dict_list.append(_cfg_dict)
121
+ cfg_text_list.append(_cfg_text)
122
+
123
+ base_cfg_dict = dict()
124
+ for c in cfg_dict_list:
125
+ if len(base_cfg_dict.keys() & c.keys()) > 0:
126
+ raise KeyError("Duplicate key is not allowed among bases")
127
+ # TODO Allow the duplicate key while warnning user
128
+ base_cfg_dict.update(c)
129
+
130
+ base_cfg_dict = SLConfig._merge_a_into_b(cfg_dict, base_cfg_dict)
131
+ cfg_dict = base_cfg_dict
132
+
133
+ # merge cfg_text
134
+ cfg_text_list.append(cfg_text)
135
+ cfg_text = "\n".join(cfg_text_list)
136
+
137
+ return cfg_dict, cfg_text
138
+
139
+ @staticmethod
140
+ def _merge_a_into_b(a, b):
141
+ """merge dict `a` into dict `b` (non-inplace).
142
+ values in `a` will overwrite `b`.
143
+ copy first to avoid inplace modification
144
+
145
+ Args:
146
+ a ([type]): [description]
147
+ b ([type]): [description]
148
+
149
+ Returns:
150
+ [dict]: [description]
151
+ """
152
+ # import ipdb; ipdb.set_trace()
153
+ if not isinstance(a, dict):
154
+ return a
155
+
156
+ b = b.copy()
157
+ for k, v in a.items():
158
+ if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
159
+
160
+ if not isinstance(b[k], dict) and not isinstance(b[k], list):
161
+ # if :
162
+ # import ipdb; ipdb.set_trace()
163
+ raise TypeError(
164
+ f"{k}={v} in child config cannot inherit from base "
165
+ f"because {k} is a dict in the child config but is of "
166
+ f"type {type(b[k])} in base config. You may set "
167
+ f"`{DELETE_KEY}=True` to ignore the base config"
168
+ )
169
+ b[k] = SLConfig._merge_a_into_b(v, b[k])
170
+ elif isinstance(b, list):
171
+ try:
172
+ _ = int(k)
173
+ except:
174
+ raise TypeError(
175
+ f"b is a list, " f"index {k} should be an int when input but {type(k)}"
176
+ )
177
+ b[int(k)] = SLConfig._merge_a_into_b(v, b[int(k)])
178
+ else:
179
+ b[k] = v
180
+
181
+ return b
182
+
183
+ @staticmethod
184
+ def fromfile(filename):
185
+ cfg_dict, cfg_text = SLConfig._file2dict(filename)
186
+ return SLConfig(cfg_dict, cfg_text=cfg_text, filename=filename)
187
+
188
+ def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
189
+ if cfg_dict is None:
190
+ cfg_dict = dict()
191
+ elif not isinstance(cfg_dict, dict):
192
+ raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
193
+ for key in cfg_dict:
194
+ if key in RESERVED_KEYS:
195
+ raise KeyError(f"{key} is reserved for config file")
196
+
197
+ super(SLConfig, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
198
+ super(SLConfig, self).__setattr__("_filename", filename)
199
+ if cfg_text:
200
+ text = cfg_text
201
+ elif filename:
202
+ with open(filename, "r") as f:
203
+ text = f.read()
204
+ else:
205
+ text = ""
206
+ super(SLConfig, self).__setattr__("_text", text)
207
+
208
+ @property
209
+ def filename(self):
210
+ return self._filename
211
+
212
+ @property
213
+ def text(self):
214
+ return self._text
215
+
216
+ @property
217
+ def pretty_text(self):
218
+
219
+ indent = 4
220
+
221
+ def _indent(s_, num_spaces):
222
+ s = s_.split("\n")
223
+ if len(s) == 1:
224
+ return s_
225
+ first = s.pop(0)
226
+ s = [(num_spaces * " ") + line for line in s]
227
+ s = "\n".join(s)
228
+ s = first + "\n" + s
229
+ return s
230
+
231
+ def _format_basic_types(k, v, use_mapping=False):
232
+ if isinstance(v, str):
233
+ v_str = f"'{v}'"
234
+ else:
235
+ v_str = str(v)
236
+
237
+ if use_mapping:
238
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
239
+ attr_str = f"{k_str}: {v_str}"
240
+ else:
241
+ attr_str = f"{str(k)}={v_str}"
242
+ attr_str = _indent(attr_str, indent)
243
+
244
+ return attr_str
245
+
246
+ def _format_list(k, v, use_mapping=False):
247
+ # check if all items in the list are dict
248
+ if all(isinstance(_, dict) for _ in v):
249
+ v_str = "[\n"
250
+ v_str += "\n".join(
251
+ f"dict({_indent(_format_dict(v_), indent)})," for v_ in v
252
+ ).rstrip(",")
253
+ if use_mapping:
254
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
255
+ attr_str = f"{k_str}: {v_str}"
256
+ else:
257
+ attr_str = f"{str(k)}={v_str}"
258
+ attr_str = _indent(attr_str, indent) + "]"
259
+ else:
260
+ attr_str = _format_basic_types(k, v, use_mapping)
261
+ return attr_str
262
+
263
+ def _contain_invalid_identifier(dict_str):
264
+ contain_invalid_identifier = False
265
+ for key_name in dict_str:
266
+ contain_invalid_identifier |= not str(key_name).isidentifier()
267
+ return contain_invalid_identifier
268
+
269
+ def _format_dict(input_dict, outest_level=False):
270
+ r = ""
271
+ s = []
272
+
273
+ use_mapping = _contain_invalid_identifier(input_dict)
274
+ if use_mapping:
275
+ r += "{"
276
+ for idx, (k, v) in enumerate(input_dict.items()):
277
+ is_last = idx >= len(input_dict) - 1
278
+ end = "" if outest_level or is_last else ","
279
+ if isinstance(v, dict):
280
+ v_str = "\n" + _format_dict(v)
281
+ if use_mapping:
282
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
283
+ attr_str = f"{k_str}: dict({v_str}"
284
+ else:
285
+ attr_str = f"{str(k)}=dict({v_str}"
286
+ attr_str = _indent(attr_str, indent) + ")" + end
287
+ elif isinstance(v, list):
288
+ attr_str = _format_list(k, v, use_mapping) + end
289
+ else:
290
+ attr_str = _format_basic_types(k, v, use_mapping) + end
291
+
292
+ s.append(attr_str)
293
+ r += "\n".join(s)
294
+ if use_mapping:
295
+ r += "}"
296
+ return r
297
+
298
+ cfg_dict = self._cfg_dict.to_dict()
299
+ text = _format_dict(cfg_dict, outest_level=True)
300
+ # copied from setup.cfg
301
+ yapf_style = dict(
302
+ based_on_style="pep8",
303
+ blank_line_before_nested_class_or_def=True,
304
+ split_before_expression_after_opening_paren=True,
305
+ )
306
+ text, _ = FormatCode(text, style_config=yapf_style, verify=True)
307
+
308
+ return text
309
+
310
+ def __repr__(self):
311
+ return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"
312
+
313
+ def __len__(self):
314
+ return len(self._cfg_dict)
315
+
316
+ def __getattr__(self, name):
317
+ # # debug
318
+ # print('+'*15)
319
+ # print('name=%s' % name)
320
+ # print("addr:", id(self))
321
+ # # print('type(self):', type(self))
322
+ # print(self.__dict__)
323
+ # print('+'*15)
324
+ # if self.__dict__ == {}:
325
+ # raise ValueError
326
+
327
+ return getattr(self._cfg_dict, name)
328
+
329
+ def __getitem__(self, name):
330
+ return self._cfg_dict.__getitem__(name)
331
+
332
+ def __setattr__(self, name, value):
333
+ if isinstance(value, dict):
334
+ value = ConfigDict(value)
335
+ self._cfg_dict.__setattr__(name, value)
336
+
337
+ def __setitem__(self, name, value):
338
+ if isinstance(value, dict):
339
+ value = ConfigDict(value)
340
+ self._cfg_dict.__setitem__(name, value)
341
+
342
+ def __iter__(self):
343
+ return iter(self._cfg_dict)
344
+
345
+ def dump(self, file=None):
346
+ # import ipdb; ipdb.set_trace()
347
+ if file is None:
348
+ return self.pretty_text
349
+ else:
350
+ with open(file, "w") as f:
351
+ f.write(self.pretty_text)
352
+
353
+ def merge_from_dict(self, options):
354
+ """Merge list into cfg_dict
355
+
356
+ Merge the dict parsed by MultipleKVAction into this cfg.
357
+
358
+ Examples:
359
+ >>> options = {'model.backbone.depth': 50,
360
+ ... 'model.backbone.with_cp':True}
361
+ >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
362
+ >>> cfg.merge_from_dict(options)
363
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
364
+ >>> assert cfg_dict == dict(
365
+ ... model=dict(backbone=dict(depth=50, with_cp=True)))
366
+
367
+ Args:
368
+ options (dict): dict of configs to merge from.
369
+ """
370
+ option_cfg_dict = {}
371
+ for full_key, v in options.items():
372
+ d = option_cfg_dict
373
+ key_list = full_key.split(".")
374
+ for subkey in key_list[:-1]:
375
+ d.setdefault(subkey, ConfigDict())
376
+ d = d[subkey]
377
+ subkey = key_list[-1]
378
+ d[subkey] = v
379
+
380
+ cfg_dict = super(SLConfig, self).__getattribute__("_cfg_dict")
381
+ super(SLConfig, self).__setattr__(
382
+ "_cfg_dict", SLConfig._merge_a_into_b(option_cfg_dict, cfg_dict)
383
+ )
384
+
385
+ # for multiprocess
386
+ def __setstate__(self, state):
387
+ self.__init__(state)
388
+
389
+ def copy(self):
390
+ return SLConfig(self._cfg_dict.copy())
391
+
392
+ def deepcopy(self):
393
+ return SLConfig(self._cfg_dict.deepcopy())
394
+
395
+
396
+ class DictAction(Action):
397
+ """
398
+ argparse action to split an argument into KEY=VALUE form
399
+ on the first = and append to a dictionary. List options should
400
+ be passed as comma separated values, i.e KEY=V1,V2,V3
401
+ """
402
+
403
+ @staticmethod
404
+ def _parse_int_float_bool(val):
405
+ try:
406
+ return int(val)
407
+ except ValueError:
408
+ pass
409
+ try:
410
+ return float(val)
411
+ except ValueError:
412
+ pass
413
+ if val.lower() in ["true", "false"]:
414
+ return True if val.lower() == "true" else False
415
+ if val.lower() in ["none", "null"]:
416
+ return None
417
+ return val
418
+
419
+ def __call__(self, parser, namespace, values, option_string=None):
420
+ options = {}
421
+ for kv in values:
422
+ key, val = kv.split("=", maxsplit=1)
423
+ val = [self._parse_int_float_bool(v) for v in val.split(",")]
424
+ if len(val) == 1:
425
+ val = val[0]
426
+ options[key] = val
427
+ setattr(namespace, self.dest, options)
GroundingDINO/groundingdino/util/slio.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==========================================================
2
+ # Modified from mmcv
3
+ # ==========================================================
4
+
5
+ import json
6
+ import pickle
7
+ from abc import ABCMeta, abstractmethod
8
+ from pathlib import Path
9
+
10
+ import yaml
11
+
12
+ try:
13
+ from yaml import CLoader as Loader, CDumper as Dumper
14
+ except ImportError:
15
+ from yaml import Loader, Dumper
16
+
17
+
18
+ # ===========================
19
+ # Rigister handler
20
+ # ===========================
21
+
22
+
23
+ class BaseFileHandler(metaclass=ABCMeta):
24
+ @abstractmethod
25
+ def load_from_fileobj(self, file, **kwargs):
26
+ pass
27
+
28
+ @abstractmethod
29
+ def dump_to_fileobj(self, obj, file, **kwargs):
30
+ pass
31
+
32
+ @abstractmethod
33
+ def dump_to_str(self, obj, **kwargs):
34
+ pass
35
+
36
+ def load_from_path(self, filepath, mode="r", **kwargs):
37
+ with open(filepath, mode) as f:
38
+ return self.load_from_fileobj(f, **kwargs)
39
+
40
+ def dump_to_path(self, obj, filepath, mode="w", **kwargs):
41
+ with open(filepath, mode) as f:
42
+ self.dump_to_fileobj(obj, f, **kwargs)
43
+
44
+
45
+ class JsonHandler(BaseFileHandler):
46
+ def load_from_fileobj(self, file):
47
+ return json.load(file)
48
+
49
+ def dump_to_fileobj(self, obj, file, **kwargs):
50
+ json.dump(obj, file, **kwargs)
51
+
52
+ def dump_to_str(self, obj, **kwargs):
53
+ return json.dumps(obj, **kwargs)
54
+
55
+
56
+ class PickleHandler(BaseFileHandler):
57
+ def load_from_fileobj(self, file, **kwargs):
58
+ return pickle.load(file, **kwargs)
59
+
60
+ def load_from_path(self, filepath, **kwargs):
61
+ return super(PickleHandler, self).load_from_path(filepath, mode="rb", **kwargs)
62
+
63
+ def dump_to_str(self, obj, **kwargs):
64
+ kwargs.setdefault("protocol", 2)
65
+ return pickle.dumps(obj, **kwargs)
66
+
67
+ def dump_to_fileobj(self, obj, file, **kwargs):
68
+ kwargs.setdefault("protocol", 2)
69
+ pickle.dump(obj, file, **kwargs)
70
+
71
+ def dump_to_path(self, obj, filepath, **kwargs):
72
+ super(PickleHandler, self).dump_to_path(obj, filepath, mode="wb", **kwargs)
73
+
74
+
75
+ class YamlHandler(BaseFileHandler):
76
+ def load_from_fileobj(self, file, **kwargs):
77
+ kwargs.setdefault("Loader", Loader)
78
+ return yaml.load(file, **kwargs)
79
+
80
+ def dump_to_fileobj(self, obj, file, **kwargs):
81
+ kwargs.setdefault("Dumper", Dumper)
82
+ yaml.dump(obj, file, **kwargs)
83
+
84
+ def dump_to_str(self, obj, **kwargs):
85
+ kwargs.setdefault("Dumper", Dumper)
86
+ return yaml.dump(obj, **kwargs)
87
+
88
+
89
+ file_handlers = {
90
+ "json": JsonHandler(),
91
+ "yaml": YamlHandler(),
92
+ "yml": YamlHandler(),
93
+ "pickle": PickleHandler(),
94
+ "pkl": PickleHandler(),
95
+ }
96
+
97
+ # ===========================
98
+ # load and dump
99
+ # ===========================
100
+
101
+
102
+ def is_str(x):
103
+ """Whether the input is an string instance.
104
+
105
+ Note: This method is deprecated since python 2 is no longer supported.
106
+ """
107
+ return isinstance(x, str)
108
+
109
+
110
+ def slload(file, file_format=None, **kwargs):
111
+ """Load data from json/yaml/pickle files.
112
+
113
+ This method provides a unified api for loading data from serialized files.
114
+
115
+ Args:
116
+ file (str or :obj:`Path` or file-like object): Filename or a file-like
117
+ object.
118
+ file_format (str, optional): If not specified, the file format will be
119
+ inferred from the file extension, otherwise use the specified one.
120
+ Currently supported formats include "json", "yaml/yml" and
121
+ "pickle/pkl".
122
+
123
+ Returns:
124
+ The content from the file.
125
+ """
126
+ if isinstance(file, Path):
127
+ file = str(file)
128
+ if file_format is None and is_str(file):
129
+ file_format = file.split(".")[-1]
130
+ if file_format not in file_handlers:
131
+ raise TypeError(f"Unsupported format: {file_format}")
132
+
133
+ handler = file_handlers[file_format]
134
+ if is_str(file):
135
+ obj = handler.load_from_path(file, **kwargs)
136
+ elif hasattr(file, "read"):
137
+ obj = handler.load_from_fileobj(file, **kwargs)
138
+ else:
139
+ raise TypeError('"file" must be a filepath str or a file-object')
140
+ return obj
141
+
142
+
143
+ def sldump(obj, file=None, file_format=None, **kwargs):
144
+ """Dump data to json/yaml/pickle strings or files.
145
+
146
+ This method provides a unified api for dumping data as strings or to files,
147
+ and also supports custom arguments for each file format.
148
+
149
+ Args:
150
+ obj (any): The python object to be dumped.
151
+ file (str or :obj:`Path` or file-like object, optional): If not
152
+ specified, then the object is dump to a str, otherwise to a file
153
+ specified by the filename or file-like object.
154
+ file_format (str, optional): Same as :func:`load`.
155
+
156
+ Returns:
157
+ bool: True for success, False otherwise.
158
+ """
159
+ if isinstance(file, Path):
160
+ file = str(file)
161
+ if file_format is None:
162
+ if is_str(file):
163
+ file_format = file.split(".")[-1]
164
+ elif file is None:
165
+ raise ValueError("file_format must be specified since file is None")
166
+ if file_format not in file_handlers:
167
+ raise TypeError(f"Unsupported format: {file_format}")
168
+
169
+ handler = file_handlers[file_format]
170
+ if file is None:
171
+ return handler.dump_to_str(obj, **kwargs)
172
+ elif is_str(file):
173
+ handler.dump_to_path(obj, file, **kwargs)
174
+ elif hasattr(file, "write"):
175
+ handler.dump_to_fileobj(obj, file, **kwargs)
176
+ else:
177
+ raise TypeError('"file" must be a filename str or a file-object')
GroundingDINO/groundingdino/util/time_counter.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+
4
+
5
+ class TimeCounter:
6
+ def __init__(self) -> None:
7
+ pass
8
+
9
+ def clear(self):
10
+ self.timedict = {}
11
+ self.basetime = time.perf_counter()
12
+
13
+ def timeit(self, name):
14
+ nowtime = time.perf_counter() - self.basetime
15
+ self.timedict[name] = nowtime
16
+ self.basetime = time.perf_counter()
17
+
18
+
19
+ class TimeHolder:
20
+ def __init__(self) -> None:
21
+ self.timedict = {}
22
+
23
+ def update(self, _timedict: dict):
24
+ for k, v in _timedict.items():
25
+ if k not in self.timedict:
26
+ self.timedict[k] = AverageMeter(name=k, val_only=True)
27
+ self.timedict[k].update(val=v)
28
+
29
+ def final_res(self):
30
+ return {k: v.avg for k, v in self.timedict.items()}
31
+
32
+ def __str__(self):
33
+ return json.dumps(self.final_res(), indent=2)
34
+
35
+
36
+ class AverageMeter(object):
37
+ """Computes and stores the average and current value"""
38
+
39
+ def __init__(self, name, fmt=":f", val_only=False):
40
+ self.name = name
41
+ self.fmt = fmt
42
+ self.val_only = val_only
43
+ self.reset()
44
+
45
+ def reset(self):
46
+ self.val = 0
47
+ self.avg = 0
48
+ self.sum = 0
49
+ self.count = 0
50
+
51
+ def update(self, val, n=1):
52
+ self.val = val
53
+ self.sum += val * n
54
+ self.count += n
55
+ self.avg = self.sum / self.count
56
+
57
+ def __str__(self):
58
+ if self.val_only:
59
+ fmtstr = "{name} {val" + self.fmt + "}"
60
+ else:
61
+ fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
62
+ return fmtstr.format(**self.__dict__)
GroundingDINO/groundingdino/util/utils.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import warnings
4
+ from collections import OrderedDict
5
+ from copy import deepcopy
6
+ from typing import Any, Dict, List
7
+
8
+ import numpy as np
9
+ import torch
10
+ from transformers import AutoTokenizer
11
+
12
+ from groundingdino.util.slconfig import SLConfig
13
+
14
+
15
+ def slprint(x, name="x"):
16
+ if isinstance(x, (torch.Tensor, np.ndarray)):
17
+ print(f"{name}.shape:", x.shape)
18
+ elif isinstance(x, (tuple, list)):
19
+ print("type x:", type(x))
20
+ for i in range(min(10, len(x))):
21
+ slprint(x[i], f"{name}[{i}]")
22
+ elif isinstance(x, dict):
23
+ for k, v in x.items():
24
+ slprint(v, f"{name}[{k}]")
25
+ else:
26
+ print(f"{name}.type:", type(x))
27
+
28
+
29
+ def clean_state_dict(state_dict):
30
+ new_state_dict = OrderedDict()
31
+ for k, v in state_dict.items():
32
+ if k[:7] == "module.":
33
+ k = k[7:] # remove `module.`
34
+ new_state_dict[k] = v
35
+ return new_state_dict
36
+
37
+
38
+ def renorm(
39
+ img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
40
+ ) -> torch.FloatTensor:
41
+ # img: tensor(3,H,W) or tensor(B,3,H,W)
42
+ # return: same as img
43
+ assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
44
+ if img.dim() == 3:
45
+ assert img.size(0) == 3, 'img.size(0) shoule be 3 but "%d". (%s)' % (
46
+ img.size(0),
47
+ str(img.size()),
48
+ )
49
+ img_perm = img.permute(1, 2, 0)
50
+ mean = torch.Tensor(mean)
51
+ std = torch.Tensor(std)
52
+ img_res = img_perm * std + mean
53
+ return img_res.permute(2, 0, 1)
54
+ else: # img.dim() == 4
55
+ assert img.size(1) == 3, 'img.size(1) shoule be 3 but "%d". (%s)' % (
56
+ img.size(1),
57
+ str(img.size()),
58
+ )
59
+ img_perm = img.permute(0, 2, 3, 1)
60
+ mean = torch.Tensor(mean)
61
+ std = torch.Tensor(std)
62
+ img_res = img_perm * std + mean
63
+ return img_res.permute(0, 3, 1, 2)
64
+
65
+
66
+ class CocoClassMapper:
67
+ def __init__(self) -> None:
68
+ self.category_map_str = {
69
+ "1": 1,
70
+ "2": 2,
71
+ "3": 3,
72
+ "4": 4,
73
+ "5": 5,
74
+ "6": 6,
75
+ "7": 7,
76
+ "8": 8,
77
+ "9": 9,
78
+ "10": 10,
79
+ "11": 11,
80
+ "13": 12,
81
+ "14": 13,
82
+ "15": 14,
83
+ "16": 15,
84
+ "17": 16,
85
+ "18": 17,
86
+ "19": 18,
87
+ "20": 19,
88
+ "21": 20,
89
+ "22": 21,
90
+ "23": 22,
91
+ "24": 23,
92
+ "25": 24,
93
+ "27": 25,
94
+ "28": 26,
95
+ "31": 27,
96
+ "32": 28,
97
+ "33": 29,
98
+ "34": 30,
99
+ "35": 31,
100
+ "36": 32,
101
+ "37": 33,
102
+ "38": 34,
103
+ "39": 35,
104
+ "40": 36,
105
+ "41": 37,
106
+ "42": 38,
107
+ "43": 39,
108
+ "44": 40,
109
+ "46": 41,
110
+ "47": 42,
111
+ "48": 43,
112
+ "49": 44,
113
+ "50": 45,
114
+ "51": 46,
115
+ "52": 47,
116
+ "53": 48,
117
+ "54": 49,
118
+ "55": 50,
119
+ "56": 51,
120
+ "57": 52,
121
+ "58": 53,
122
+ "59": 54,
123
+ "60": 55,
124
+ "61": 56,
125
+ "62": 57,
126
+ "63": 58,
127
+ "64": 59,
128
+ "65": 60,
129
+ "67": 61,
130
+ "70": 62,
131
+ "72": 63,
132
+ "73": 64,
133
+ "74": 65,
134
+ "75": 66,
135
+ "76": 67,
136
+ "77": 68,
137
+ "78": 69,
138
+ "79": 70,
139
+ "80": 71,
140
+ "81": 72,
141
+ "82": 73,
142
+ "84": 74,
143
+ "85": 75,
144
+ "86": 76,
145
+ "87": 77,
146
+ "88": 78,
147
+ "89": 79,
148
+ "90": 80,
149
+ }
150
+ self.origin2compact_mapper = {int(k): v - 1 for k, v in self.category_map_str.items()}
151
+ self.compact2origin_mapper = {int(v - 1): int(k) for k, v in self.category_map_str.items()}
152
+
153
+ def origin2compact(self, idx):
154
+ return self.origin2compact_mapper[int(idx)]
155
+
156
+ def compact2origin(self, idx):
157
+ return self.compact2origin_mapper[int(idx)]
158
+
159
+
160
+ def to_device(item, device):
161
+ if isinstance(item, torch.Tensor):
162
+ return item.to(device)
163
+ elif isinstance(item, list):
164
+ return [to_device(i, device) for i in item]
165
+ elif isinstance(item, dict):
166
+ return {k: to_device(v, device) for k, v in item.items()}
167
+ else:
168
+ raise NotImplementedError(
169
+ "Call Shilong if you use other containers! type: {}".format(type(item))
170
+ )
171
+
172
+
173
+ #
174
+ def get_gaussian_mean(x, axis, other_axis, softmax=True):
175
+ """
176
+
177
+ Args:
178
+ x (float): Input images(BxCxHxW)
179
+ axis (int): The index for weighted mean
180
+ other_axis (int): The other index
181
+
182
+ Returns: weighted index for axis, BxC
183
+
184
+ """
185
+ mat2line = torch.sum(x, axis=other_axis)
186
+ # mat2line = mat2line / mat2line.mean() * 10
187
+ if softmax:
188
+ u = torch.softmax(mat2line, axis=2)
189
+ else:
190
+ u = mat2line / (mat2line.sum(2, keepdim=True) + 1e-6)
191
+ size = x.shape[axis]
192
+ ind = torch.linspace(0, 1, size).to(x.device)
193
+ batch = x.shape[0]
194
+ channel = x.shape[1]
195
+ index = ind.repeat([batch, channel, 1])
196
+ mean_position = torch.sum(index * u, dim=2)
197
+ return mean_position
198
+
199
+
200
+ def get_expected_points_from_map(hm, softmax=True):
201
+ """get_gaussian_map_from_points
202
+ B,C,H,W -> B,N,2 float(0, 1) float(0, 1)
203
+ softargmax function
204
+
205
+ Args:
206
+ hm (float): Input images(BxCxHxW)
207
+
208
+ Returns:
209
+ weighted index for axis, BxCx2. float between 0 and 1.
210
+
211
+ """
212
+ # hm = 10*hm
213
+ B, C, H, W = hm.shape
214
+ y_mean = get_gaussian_mean(hm, 2, 3, softmax=softmax) # B,C
215
+ x_mean = get_gaussian_mean(hm, 3, 2, softmax=softmax) # B,C
216
+ # return torch.cat((x_mean.unsqueeze(-1), y_mean.unsqueeze(-1)), 2)
217
+ return torch.stack([x_mean, y_mean], dim=2)
218
+
219
+
220
+ # Positional encoding (section 5.1)
221
+ # borrow from nerf
222
+ class Embedder:
223
+ def __init__(self, **kwargs):
224
+ self.kwargs = kwargs
225
+ self.create_embedding_fn()
226
+
227
+ def create_embedding_fn(self):
228
+ embed_fns = []
229
+ d = self.kwargs["input_dims"]
230
+ out_dim = 0
231
+ if self.kwargs["include_input"]:
232
+ embed_fns.append(lambda x: x)
233
+ out_dim += d
234
+
235
+ max_freq = self.kwargs["max_freq_log2"]
236
+ N_freqs = self.kwargs["num_freqs"]
237
+
238
+ if self.kwargs["log_sampling"]:
239
+ freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)
240
+ else:
241
+ freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs)
242
+
243
+ for freq in freq_bands:
244
+ for p_fn in self.kwargs["periodic_fns"]:
245
+ embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
246
+ out_dim += d
247
+
248
+ self.embed_fns = embed_fns
249
+ self.out_dim = out_dim
250
+
251
+ def embed(self, inputs):
252
+ return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
253
+
254
+
255
+ def get_embedder(multires, i=0):
256
+ import torch.nn as nn
257
+
258
+ if i == -1:
259
+ return nn.Identity(), 3
260
+
261
+ embed_kwargs = {
262
+ "include_input": True,
263
+ "input_dims": 3,
264
+ "max_freq_log2": multires - 1,
265
+ "num_freqs": multires,
266
+ "log_sampling": True,
267
+ "periodic_fns": [torch.sin, torch.cos],
268
+ }
269
+
270
+ embedder_obj = Embedder(**embed_kwargs)
271
+ embed = lambda x, eo=embedder_obj: eo.embed(x)
272
+ return embed, embedder_obj.out_dim
273
+
274
+
275
+ class APOPMeter:
276
+ def __init__(self) -> None:
277
+ self.tp = 0
278
+ self.fp = 0
279
+ self.tn = 0
280
+ self.fn = 0
281
+
282
+ def update(self, pred, gt):
283
+ """
284
+ Input:
285
+ pred, gt: Tensor()
286
+ """
287
+ assert pred.shape == gt.shape
288
+ self.tp += torch.logical_and(pred == 1, gt == 1).sum().item()
289
+ self.fp += torch.logical_and(pred == 1, gt == 0).sum().item()
290
+ self.tn += torch.logical_and(pred == 0, gt == 0).sum().item()
291
+ self.tn += torch.logical_and(pred == 1, gt == 0).sum().item()
292
+
293
+ def update_cm(self, tp, fp, tn, fn):
294
+ self.tp += tp
295
+ self.fp += fp
296
+ self.tn += tn
297
+ self.tn += fn
298
+
299
+
300
+ def inverse_sigmoid(x, eps=1e-5):
301
+ x = x.clamp(min=0, max=1)
302
+ x1 = x.clamp(min=eps)
303
+ x2 = (1 - x).clamp(min=eps)
304
+ return torch.log(x1 / x2)
305
+
306
+
307
+ def get_raw_dict(args):
308
+ """
309
+ return the dicf contained in args.
310
+
311
+ e.g:
312
+ >>> with open(path, 'w') as f:
313
+ json.dump(get_raw_dict(args), f, indent=2)
314
+ """
315
+ if isinstance(args, argparse.Namespace):
316
+ return vars(args)
317
+ elif isinstance(args, dict):
318
+ return args
319
+ elif isinstance(args, SLConfig):
320
+ return args._cfg_dict
321
+ else:
322
+ raise NotImplementedError("Unknown type {}".format(type(args)))
323
+
324
+
325
+ def stat_tensors(tensor):
326
+ assert tensor.dim() == 1
327
+ tensor_sm = tensor.softmax(0)
328
+ entropy = (tensor_sm * torch.log(tensor_sm + 1e-9)).sum()
329
+
330
+ return {
331
+ "max": tensor.max(),
332
+ "min": tensor.min(),
333
+ "mean": tensor.mean(),
334
+ "var": tensor.var(),
335
+ "std": tensor.var() ** 0.5,
336
+ "entropy": entropy,
337
+ }
338
+
339
+
340
+ class NiceRepr:
341
+ """Inherit from this class and define ``__nice__`` to "nicely" print your
342
+ objects.
343
+
344
+ Defines ``__str__`` and ``__repr__`` in terms of ``__nice__`` function
345
+ Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``.
346
+ If the inheriting class has a ``__len__``, method then the default
347
+ ``__nice__`` method will return its length.
348
+
349
+ Example:
350
+ >>> class Foo(NiceRepr):
351
+ ... def __nice__(self):
352
+ ... return 'info'
353
+ >>> foo = Foo()
354
+ >>> assert str(foo) == '<Foo(info)>'
355
+ >>> assert repr(foo).startswith('<Foo(info) at ')
356
+
357
+ Example:
358
+ >>> class Bar(NiceRepr):
359
+ ... pass
360
+ >>> bar = Bar()
361
+ >>> import pytest
362
+ >>> with pytest.warns(None) as record:
363
+ >>> assert 'object at' in str(bar)
364
+ >>> assert 'object at' in repr(bar)
365
+
366
+ Example:
367
+ >>> class Baz(NiceRepr):
368
+ ... def __len__(self):
369
+ ... return 5
370
+ >>> baz = Baz()
371
+ >>> assert str(baz) == '<Baz(5)>'
372
+ """
373
+
374
+ def __nice__(self):
375
+ """str: a "nice" summary string describing this module"""
376
+ if hasattr(self, "__len__"):
377
+ # It is a common pattern for objects to use __len__ in __nice__
378
+ # As a convenience we define a default __nice__ for these objects
379
+ return str(len(self))
380
+ else:
381
+ # In all other cases force the subclass to overload __nice__
382
+ raise NotImplementedError(f"Define the __nice__ method for {self.__class__!r}")
383
+
384
+ def __repr__(self):
385
+ """str: the string of the module"""
386
+ try:
387
+ nice = self.__nice__()
388
+ classname = self.__class__.__name__
389
+ return f"<{classname}({nice}) at {hex(id(self))}>"
390
+ except NotImplementedError as ex:
391
+ warnings.warn(str(ex), category=RuntimeWarning)
392
+ return object.__repr__(self)
393
+
394
+ def __str__(self):
395
+ """str: the string of the module"""
396
+ try:
397
+ classname = self.__class__.__name__
398
+ nice = self.__nice__()
399
+ return f"<{classname}({nice})>"
400
+ except NotImplementedError as ex:
401
+ warnings.warn(str(ex), category=RuntimeWarning)
402
+ return object.__repr__(self)
403
+
404
+
405
+ def ensure_rng(rng=None):
406
+ """Coerces input into a random number generator.
407
+
408
+ If the input is None, then a global random state is returned.
409
+
410
+ If the input is a numeric value, then that is used as a seed to construct a
411
+ random state. Otherwise the input is returned as-is.
412
+
413
+ Adapted from [1]_.
414
+
415
+ Args:
416
+ rng (int | numpy.random.RandomState | None):
417
+ if None, then defaults to the global rng. Otherwise this can be an
418
+ integer or a RandomState class
419
+ Returns:
420
+ (numpy.random.RandomState) : rng -
421
+ a numpy random number generator
422
+
423
+ References:
424
+ .. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 # noqa: E501
425
+ """
426
+
427
+ if rng is None:
428
+ rng = np.random.mtrand._rand
429
+ elif isinstance(rng, int):
430
+ rng = np.random.RandomState(rng)
431
+ else:
432
+ rng = rng
433
+ return rng
434
+
435
+
436
+ def random_boxes(num=1, scale=1, rng=None):
437
+ """Simple version of ``kwimage.Boxes.random``
438
+
439
+ Returns:
440
+ Tensor: shape (n, 4) in x1, y1, x2, y2 format.
441
+
442
+ References:
443
+ https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390
444
+
445
+ Example:
446
+ >>> num = 3
447
+ >>> scale = 512
448
+ >>> rng = 0
449
+ >>> boxes = random_boxes(num, scale, rng)
450
+ >>> print(boxes)
451
+ tensor([[280.9925, 278.9802, 308.6148, 366.1769],
452
+ [216.9113, 330.6978, 224.0446, 456.5878],
453
+ [405.3632, 196.3221, 493.3953, 270.7942]])
454
+ """
455
+ rng = ensure_rng(rng)
456
+
457
+ tlbr = rng.rand(num, 4).astype(np.float32)
458
+
459
+ tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2])
460
+ tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3])
461
+ br_x = np.maximum(tlbr[:, 0], tlbr[:, 2])
462
+ br_y = np.maximum(tlbr[:, 1], tlbr[:, 3])
463
+
464
+ tlbr[:, 0] = tl_x * scale
465
+ tlbr[:, 1] = tl_y * scale
466
+ tlbr[:, 2] = br_x * scale
467
+ tlbr[:, 3] = br_y * scale
468
+
469
+ boxes = torch.from_numpy(tlbr)
470
+ return boxes
471
+
472
+
473
+ class ModelEma(torch.nn.Module):
474
+ def __init__(self, model, decay=0.9997, device=None):
475
+ super(ModelEma, self).__init__()
476
+ # make a copy of the model for accumulating moving average of weights
477
+ self.module = deepcopy(model)
478
+ self.module.eval()
479
+
480
+ # import ipdb; ipdb.set_trace()
481
+
482
+ self.decay = decay
483
+ self.device = device # perform ema on different device from model if set
484
+ if self.device is not None:
485
+ self.module.to(device=device)
486
+
487
+ def _update(self, model, update_fn):
488
+ with torch.no_grad():
489
+ for ema_v, model_v in zip(
490
+ self.module.state_dict().values(), model.state_dict().values()
491
+ ):
492
+ if self.device is not None:
493
+ model_v = model_v.to(device=self.device)
494
+ ema_v.copy_(update_fn(ema_v, model_v))
495
+
496
+ def update(self, model):
497
+ self._update(model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m)
498
+
499
+ def set(self, model):
500
+ self._update(model, update_fn=lambda e, m: m)
501
+
502
+
503
+ class BestMetricSingle:
504
+ def __init__(self, init_res=0.0, better="large") -> None:
505
+ self.init_res = init_res
506
+ self.best_res = init_res
507
+ self.best_ep = -1
508
+
509
+ self.better = better
510
+ assert better in ["large", "small"]
511
+
512
+ def isbetter(self, new_res, old_res):
513
+ if self.better == "large":
514
+ return new_res > old_res
515
+ if self.better == "small":
516
+ return new_res < old_res
517
+
518
+ def update(self, new_res, ep):
519
+ if self.isbetter(new_res, self.best_res):
520
+ self.best_res = new_res
521
+ self.best_ep = ep
522
+ return True
523
+ return False
524
+
525
+ def __str__(self) -> str:
526
+ return "best_res: {}\t best_ep: {}".format(self.best_res, self.best_ep)
527
+
528
+ def __repr__(self) -> str:
529
+ return self.__str__()
530
+
531
+ def summary(self) -> dict:
532
+ return {
533
+ "best_res": self.best_res,
534
+ "best_ep": self.best_ep,
535
+ }
536
+
537
+
538
+ class BestMetricHolder:
539
+ def __init__(self, init_res=0.0, better="large", use_ema=False) -> None:
540
+ self.best_all = BestMetricSingle(init_res, better)
541
+ self.use_ema = use_ema
542
+ if use_ema:
543
+ self.best_ema = BestMetricSingle(init_res, better)
544
+ self.best_regular = BestMetricSingle(init_res, better)
545
+
546
+ def update(self, new_res, epoch, is_ema=False):
547
+ """
548
+ return if the results is the best.
549
+ """
550
+ if not self.use_ema:
551
+ return self.best_all.update(new_res, epoch)
552
+ else:
553
+ if is_ema:
554
+ self.best_ema.update(new_res, epoch)
555
+ return self.best_all.update(new_res, epoch)
556
+ else:
557
+ self.best_regular.update(new_res, epoch)
558
+ return self.best_all.update(new_res, epoch)
559
+
560
+ def summary(self):
561
+ if not self.use_ema:
562
+ return self.best_all.summary()
563
+
564
+ res = {}
565
+ res.update({f"all_{k}": v for k, v in self.best_all.summary().items()})
566
+ res.update({f"regular_{k}": v for k, v in self.best_regular.summary().items()})
567
+ res.update({f"ema_{k}": v for k, v in self.best_ema.summary().items()})
568
+ return res
569
+
570
+ def __repr__(self) -> str:
571
+ return json.dumps(self.summary(), indent=2)
572
+
573
+ def __str__(self) -> str:
574
+ return self.__repr__()
575
+
576
+
577
+ def targets_to(targets: List[Dict[str, Any]], device):
578
+ """Moves the target dicts to the given device."""
579
+ excluded_keys = [
580
+ "questionId",
581
+ "tokens_positive",
582
+ "strings_positive",
583
+ "tokens",
584
+ "dataset_name",
585
+ "sentence_id",
586
+ "original_img_id",
587
+ "nb_eval",
588
+ "task_id",
589
+ "original_id",
590
+ "token_span",
591
+ "caption",
592
+ "dataset_type",
593
+ ]
594
+ return [
595
+ {k: v.to(device) if k not in excluded_keys else v for k, v in t.items()} for t in targets
596
+ ]
597
+
598
+
599
+ def get_phrases_from_posmap(
600
+ posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer
601
+ ):
602
+ assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
603
+ if posmap.dim() == 1:
604
+ non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist()
605
+ token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
606
+ return tokenizer.decode(token_ids)
607
+ else:
608
+ raise NotImplementedError("posmap must be 1-dim")
GroundingDINO/groundingdino/util/visualizer.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @File : visualizer.py
4
+ @Time : 2022/04/05 11:39:33
5
+ @Author : Shilong Liu
6
+ @Contact : slongliu86@gmail.com
7
+ """
8
+
9
+ import datetime
10
+ import os
11
+
12
+ import cv2
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ import torch
16
+ from matplotlib import transforms
17
+ from matplotlib.collections import PatchCollection
18
+ from matplotlib.patches import Polygon
19
+ from pycocotools import mask as maskUtils
20
+
21
+
22
+ def renorm(
23
+ img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
24
+ ) -> torch.FloatTensor:
25
+ # img: tensor(3,H,W) or tensor(B,3,H,W)
26
+ # return: same as img
27
+ assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
28
+ if img.dim() == 3:
29
+ assert img.size(0) == 3, 'img.size(0) shoule be 3 but "%d". (%s)' % (
30
+ img.size(0),
31
+ str(img.size()),
32
+ )
33
+ img_perm = img.permute(1, 2, 0)
34
+ mean = torch.Tensor(mean)
35
+ std = torch.Tensor(std)
36
+ img_res = img_perm * std + mean
37
+ return img_res.permute(2, 0, 1)
38
+ else: # img.dim() == 4
39
+ assert img.size(1) == 3, 'img.size(1) shoule be 3 but "%d". (%s)' % (
40
+ img.size(1),
41
+ str(img.size()),
42
+ )
43
+ img_perm = img.permute(0, 2, 3, 1)
44
+ mean = torch.Tensor(mean)
45
+ std = torch.Tensor(std)
46
+ img_res = img_perm * std + mean
47
+ return img_res.permute(0, 3, 1, 2)
48
+
49
+
50
+ class ColorMap:
51
+ def __init__(self, basergb=[255, 255, 0]):
52
+ self.basergb = np.array(basergb)
53
+
54
+ def __call__(self, attnmap):
55
+ # attnmap: h, w. np.uint8.
56
+ # return: h, w, 4. np.uint8.
57
+ assert attnmap.dtype == np.uint8
58
+ h, w = attnmap.shape
59
+ res = self.basergb.copy()
60
+ res = res[None][None].repeat(h, 0).repeat(w, 1) # h, w, 3
61
+ attn1 = attnmap.copy()[..., None] # h, w, 1
62
+ res = np.concatenate((res, attn1), axis=-1).astype(np.uint8)
63
+ return res
64
+
65
+
66
+ def rainbow_text(x, y, ls, lc, **kw):
67
+ """
68
+ Take a list of strings ``ls`` and colors ``lc`` and place them next to each
69
+ other, with text ls[i] being shown in color lc[i].
70
+
71
+ This example shows how to do both vertical and horizontal text, and will
72
+ pass all keyword arguments to plt.text, so you can set the font size,
73
+ family, etc.
74
+ """
75
+ t = plt.gca().transData
76
+ fig = plt.gcf()
77
+ plt.show()
78
+
79
+ # horizontal version
80
+ for s, c in zip(ls, lc):
81
+ text = plt.text(x, y, " " + s + " ", color=c, transform=t, **kw)
82
+ text.draw(fig.canvas.get_renderer())
83
+ ex = text.get_window_extent()
84
+ t = transforms.offset_copy(text._transform, x=ex.width, units="dots")
85
+
86
+ # #vertical version
87
+ # for s,c in zip(ls,lc):
88
+ # text = plt.text(x,y," "+s+" ",color=c, transform=t,
89
+ # rotation=90,va='bottom',ha='center',**kw)
90
+ # text.draw(fig.canvas.get_renderer())
91
+ # ex = text.get_window_extent()
92
+ # t = transforms.offset_copy(text._transform, y=ex.height, units='dots')
93
+
94
+
95
+ class COCOVisualizer:
96
+ def __init__(self, coco=None, tokenlizer=None) -> None:
97
+ self.coco = coco
98
+
99
+ def visualize(self, img, tgt, caption=None, dpi=180, savedir="vis"):
100
+ """
101
+ img: tensor(3, H, W)
102
+ tgt: make sure they are all on cpu.
103
+ must have items: 'image_id', 'boxes', 'size'
104
+ """
105
+ plt.figure(dpi=dpi)
106
+ plt.rcParams["font.size"] = "5"
107
+ ax = plt.gca()
108
+ img = renorm(img).permute(1, 2, 0)
109
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
110
+ # import ipdb; ipdb.set_trace()
111
+ ax.imshow(img)
112
+
113
+ self.addtgt(tgt)
114
+
115
+ if tgt is None:
116
+ image_id = 0
117
+ elif "image_id" not in tgt:
118
+ image_id = 0
119
+ else:
120
+ image_id = tgt["image_id"]
121
+
122
+ if caption is None:
123
+ savename = "{}/{}-{}.png".format(
124
+ savedir, int(image_id), str(datetime.datetime.now()).replace(" ", "-")
125
+ )
126
+ else:
127
+ savename = "{}/{}-{}-{}.png".format(
128
+ savedir, caption, int(image_id), str(datetime.datetime.now()).replace(" ", "-")
129
+ )
130
+ print("savename: {}".format(savename))
131
+ os.makedirs(os.path.dirname(savename), exist_ok=True)
132
+ plt.savefig(savename)
133
+ plt.close()
134
+
135
+ def addtgt(self, tgt):
136
+ """ """
137
+ if tgt is None or not "boxes" in tgt:
138
+ ax = plt.gca()
139
+
140
+ if "caption" in tgt:
141
+ ax.set_title(tgt["caption"], wrap=True)
142
+
143
+ ax.set_axis_off()
144
+ return
145
+
146
+ ax = plt.gca()
147
+ H, W = tgt["size"]
148
+ numbox = tgt["boxes"].shape[0]
149
+
150
+ color = []
151
+ polygons = []
152
+ boxes = []
153
+ for box in tgt["boxes"].cpu():
154
+ unnormbbox = box * torch.Tensor([W, H, W, H])
155
+ unnormbbox[:2] -= unnormbbox[2:] / 2
156
+ [bbox_x, bbox_y, bbox_w, bbox_h] = unnormbbox.tolist()
157
+ boxes.append([bbox_x, bbox_y, bbox_w, bbox_h])
158
+ poly = [
159
+ [bbox_x, bbox_y],
160
+ [bbox_x, bbox_y + bbox_h],
161
+ [bbox_x + bbox_w, bbox_y + bbox_h],
162
+ [bbox_x + bbox_w, bbox_y],
163
+ ]
164
+ np_poly = np.array(poly).reshape((4, 2))
165
+ polygons.append(Polygon(np_poly))
166
+ c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
167
+ color.append(c)
168
+
169
+ p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1)
170
+ ax.add_collection(p)
171
+ p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
172
+ ax.add_collection(p)
173
+
174
+ if "strings_positive" in tgt and len(tgt["strings_positive"]) > 0:
175
+ assert (
176
+ len(tgt["strings_positive"]) == numbox
177
+ ), f"{len(tgt['strings_positive'])} = {numbox}, "
178
+ for idx, strlist in enumerate(tgt["strings_positive"]):
179
+ cate_id = int(tgt["labels"][idx])
180
+ _string = str(cate_id) + ":" + " ".join(strlist)
181
+ bbox_x, bbox_y, bbox_w, bbox_h = boxes[idx]
182
+ # ax.text(bbox_x, bbox_y, _string, color='black', bbox={'facecolor': 'yellow', 'alpha': 1.0, 'pad': 1})
183
+ ax.text(
184
+ bbox_x,
185
+ bbox_y,
186
+ _string,
187
+ color="black",
188
+ bbox={"facecolor": color[idx], "alpha": 0.6, "pad": 1},
189
+ )
190
+
191
+ if "box_label" in tgt:
192
+ assert len(tgt["box_label"]) == numbox, f"{len(tgt['box_label'])} = {numbox}, "
193
+ for idx, bl in enumerate(tgt["box_label"]):
194
+ _string = str(bl)
195
+ bbox_x, bbox_y, bbox_w, bbox_h = boxes[idx]
196
+ # ax.text(bbox_x, bbox_y, _string, color='black', bbox={'facecolor': 'yellow', 'alpha': 1.0, 'pad': 1})
197
+ ax.text(
198
+ bbox_x,
199
+ bbox_y,
200
+ _string,
201
+ color="black",
202
+ bbox={"facecolor": color[idx], "alpha": 0.6, "pad": 1},
203
+ )
204
+
205
+ if "caption" in tgt:
206
+ ax.set_title(tgt["caption"], wrap=True)
207
+ # plt.figure()
208
+ # rainbow_text(0.0,0.0,"all unicorns poop rainbows ! ! !".split(),
209
+ # ['red', 'orange', 'brown', 'green', 'blue', 'purple', 'black'])
210
+
211
+ if "attn" in tgt:
212
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
213
+ # import ipdb; ipdb.set_trace()
214
+ if isinstance(tgt["attn"], tuple):
215
+ tgt["attn"] = [tgt["attn"]]
216
+ for item in tgt["attn"]:
217
+ attn_map, basergb = item
218
+ attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-3)
219
+ attn_map = (attn_map * 255).astype(np.uint8)
220
+ cm = ColorMap(basergb)
221
+ heatmap = cm(attn_map)
222
+ ax.imshow(heatmap)
223
+ ax.set_axis_off()
224
+
225
+ def showAnns(self, anns, draw_bbox=False):
226
+ """
227
+ Display the specified annotations.
228
+ :param anns (array of object): annotations to display
229
+ :return: None
230
+ """
231
+ if len(anns) == 0:
232
+ return 0
233
+ if "segmentation" in anns[0] or "keypoints" in anns[0]:
234
+ datasetType = "instances"
235
+ elif "caption" in anns[0]:
236
+ datasetType = "captions"
237
+ else:
238
+ raise Exception("datasetType not supported")
239
+ if datasetType == "instances":
240
+ ax = plt.gca()
241
+ ax.set_autoscale_on(False)
242
+ polygons = []
243
+ color = []
244
+ for ann in anns:
245
+ c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
246
+ if "segmentation" in ann:
247
+ if type(ann["segmentation"]) == list:
248
+ # polygon
249
+ for seg in ann["segmentation"]:
250
+ poly = np.array(seg).reshape((int(len(seg) / 2), 2))
251
+ polygons.append(Polygon(poly))
252
+ color.append(c)
253
+ else:
254
+ # mask
255
+ t = self.imgs[ann["image_id"]]
256
+ if type(ann["segmentation"]["counts"]) == list:
257
+ rle = maskUtils.frPyObjects(
258
+ [ann["segmentation"]], t["height"], t["width"]
259
+ )
260
+ else:
261
+ rle = [ann["segmentation"]]
262
+ m = maskUtils.decode(rle)
263
+ img = np.ones((m.shape[0], m.shape[1], 3))
264
+ if ann["iscrowd"] == 1:
265
+ color_mask = np.array([2.0, 166.0, 101.0]) / 255
266
+ if ann["iscrowd"] == 0:
267
+ color_mask = np.random.random((1, 3)).tolist()[0]
268
+ for i in range(3):
269
+ img[:, :, i] = color_mask[i]
270
+ ax.imshow(np.dstack((img, m * 0.5)))
271
+ if "keypoints" in ann and type(ann["keypoints"]) == list:
272
+ # turn skeleton into zero-based index
273
+ sks = np.array(self.loadCats(ann["category_id"])[0]["skeleton"]) - 1
274
+ kp = np.array(ann["keypoints"])
275
+ x = kp[0::3]
276
+ y = kp[1::3]
277
+ v = kp[2::3]
278
+ for sk in sks:
279
+ if np.all(v[sk] > 0):
280
+ plt.plot(x[sk], y[sk], linewidth=3, color=c)
281
+ plt.plot(
282
+ x[v > 0],
283
+ y[v > 0],
284
+ "o",
285
+ markersize=8,
286
+ markerfacecolor=c,
287
+ markeredgecolor="k",
288
+ markeredgewidth=2,
289
+ )
290
+ plt.plot(
291
+ x[v > 1],
292
+ y[v > 1],
293
+ "o",
294
+ markersize=8,
295
+ markerfacecolor=c,
296
+ markeredgecolor=c,
297
+ markeredgewidth=2,
298
+ )
299
+
300
+ if draw_bbox:
301
+ [bbox_x, bbox_y, bbox_w, bbox_h] = ann["bbox"]
302
+ poly = [
303
+ [bbox_x, bbox_y],
304
+ [bbox_x, bbox_y + bbox_h],
305
+ [bbox_x + bbox_w, bbox_y + bbox_h],
306
+ [bbox_x + bbox_w, bbox_y],
307
+ ]
308
+ np_poly = np.array(poly).reshape((4, 2))
309
+ polygons.append(Polygon(np_poly))
310
+ color.append(c)
311
+
312
+ # p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
313
+ # ax.add_collection(p)
314
+ p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
315
+ ax.add_collection(p)
316
+ elif datasetType == "captions":
317
+ for ann in anns:
318
+ print(ann["caption"])
GroundingDINO/groundingdino/util/vl_utils.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from typing import List
4
+
5
+ import torch
6
+
7
+
8
+ def create_positive_map_from_span(tokenized, token_span, max_text_len=256):
9
+ """construct a map such that positive_map[i,j] = True iff box i is associated to token j
10
+ Input:
11
+ - tokenized:
12
+ - input_ids: Tensor[1, ntokens]
13
+ - attention_mask: Tensor[1, ntokens]
14
+ - token_span: list with length num_boxes.
15
+ - each item: [start_idx, end_idx]
16
+ """
17
+ positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float)
18
+ for j, tok_list in enumerate(token_span):
19
+ for (beg, end) in tok_list:
20
+ beg_pos = tokenized.char_to_token(beg)
21
+ end_pos = tokenized.char_to_token(end - 1)
22
+ if beg_pos is None:
23
+ try:
24
+ beg_pos = tokenized.char_to_token(beg + 1)
25
+ if beg_pos is None:
26
+ beg_pos = tokenized.char_to_token(beg + 2)
27
+ except:
28
+ beg_pos = None
29
+ if end_pos is None:
30
+ try:
31
+ end_pos = tokenized.char_to_token(end - 2)
32
+ if end_pos is None:
33
+ end_pos = tokenized.char_to_token(end - 3)
34
+ except:
35
+ end_pos = None
36
+ if beg_pos is None or end_pos is None:
37
+ continue
38
+
39
+ assert beg_pos is not None and end_pos is not None
40
+ if os.environ.get("SHILONG_DEBUG_ONLY_ONE_POS", None) == "TRUE":
41
+ positive_map[j, beg_pos] = 1
42
+ break
43
+ else:
44
+ positive_map[j, beg_pos : end_pos + 1].fill_(1)
45
+
46
+ return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
47
+
48
+
49
+ def build_captions_and_token_span(cat_list, force_lowercase):
50
+ """
51
+ Return:
52
+ captions: str
53
+ cat2tokenspan: dict
54
+ {
55
+ 'dog': [[0, 2]],
56
+ ...
57
+ }
58
+ """
59
+
60
+ cat2tokenspan = {}
61
+ captions = ""
62
+ for catname in cat_list:
63
+ class_name = catname
64
+ if force_lowercase:
65
+ class_name = class_name.lower()
66
+ if "/" in class_name:
67
+ class_name_list: List = class_name.strip().split("/")
68
+ class_name_list.append(class_name)
69
+ class_name: str = random.choice(class_name_list)
70
+
71
+ tokens_positive_i = []
72
+ subnamelist = [i.strip() for i in class_name.strip().split(" ")]
73
+ for subname in subnamelist:
74
+ if len(subname) == 0:
75
+ continue
76
+ if len(captions) > 0:
77
+ captions = captions + " "
78
+ strat_idx = len(captions)
79
+ end_idx = strat_idx + len(subname)
80
+ tokens_positive_i.append([strat_idx, end_idx])
81
+ captions = captions + subname
82
+
83
+ if len(tokens_positive_i) > 0:
84
+ captions = captions + " ."
85
+ cat2tokenspan[class_name] = tokens_positive_i
86
+
87
+ return captions, cat2tokenspan
88
+
89
+
90
+ def build_id2posspan_and_caption(category_dict: dict):
91
+ """Build id2pos_span and caption from category_dict
92
+
93
+ Args:
94
+ category_dict (dict): category_dict
95
+ """
96
+ cat_list = [item["name"].lower() for item in category_dict]
97
+ id2catname = {item["id"]: item["name"].lower() for item in category_dict}
98
+ caption, cat2posspan = build_captions_and_token_span(cat_list, force_lowercase=True)
99
+ id2posspan = {catid: cat2posspan[catname] for catid, catname in id2catname.items()}
100
+ return id2posspan, caption
GroundingDINO/groundingdino/version.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = '0.1.0'
GroundingDINO/setup.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The IDEA Authors. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------------------------------
16
+ # Modified from
17
+ # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/setup.py
18
+ # https://github.com/facebookresearch/detectron2/blob/main/setup.py
19
+ # https://github.com/open-mmlab/mmdetection/blob/master/setup.py
20
+ # https://github.com/Oneflow-Inc/libai/blob/main/setup.py
21
+ # ------------------------------------------------------------------------------------------------
22
+
23
+ import glob
24
+ import os
25
+ import subprocess
26
+
27
+ import torch
28
+ from setuptools import find_packages, setup
29
+ from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
30
+
31
+ # groundingdino version info
32
+ version = "0.1.0"
33
+ package_name = "groundingdino"
34
+ cwd = os.path.dirname(os.path.abspath(__file__))
35
+
36
+
37
+ sha = "Unknown"
38
+ try:
39
+ sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
40
+ except Exception:
41
+ pass
42
+
43
+
44
+ def write_version_file():
45
+ version_path = os.path.join(cwd, "groundingdino", "version.py")
46
+ with open(version_path, "w") as f:
47
+ f.write(f"__version__ = '{version}'\n")
48
+ # f.write(f"git_version = {repr(sha)}\n")
49
+
50
+
51
+ requirements = ["torch", "torchvision"]
52
+
53
+ torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
54
+
55
+
56
+ def get_extensions():
57
+ this_dir = os.path.dirname(os.path.abspath(__file__))
58
+ extensions_dir = os.path.join(this_dir, "groundingdino", "models", "GroundingDINO", "csrc")
59
+
60
+ main_source = os.path.join(extensions_dir, "vision.cpp")
61
+ sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
62
+ source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob(
63
+ os.path.join(extensions_dir, "*.cu")
64
+ )
65
+
66
+ sources = [main_source] + sources
67
+
68
+ # We need these variables to build with CUDA when we create the Docker image
69
+ # It solves https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/53
70
+ # and https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/84 when running
71
+ # inside a Docker container.
72
+ am_i_docker = os.environ.get('AM_I_DOCKER', '').casefold() in ['true', '1', 't']
73
+ use_cuda = os.environ.get('BUILD_WITH_CUDA', '').casefold() in ['true', '1', 't']
74
+
75
+ extension = CppExtension
76
+
77
+ extra_compile_args = {"cxx": []}
78
+ define_macros = []
79
+
80
+ if (torch.cuda.is_available() and CUDA_HOME is not None) or \
81
+ (am_i_docker and use_cuda):
82
+ print("Compiling with CUDA")
83
+ extension = CUDAExtension
84
+ sources += source_cuda
85
+ define_macros += [("WITH_CUDA", None)]
86
+ extra_compile_args["nvcc"] = [
87
+ "-DCUDA_HAS_FP16=1",
88
+ "-D__CUDA_NO_HALF_OPERATORS__",
89
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
90
+ "-D__CUDA_NO_HALF2_OPERATORS__",
91
+ ]
92
+ else:
93
+ print("Compiling without CUDA")
94
+ define_macros += [("WITH_HIP", None)]
95
+ extra_compile_args["nvcc"] = []
96
+ return None
97
+
98
+ sources = [os.path.join(extensions_dir, s) for s in sources]
99
+ include_dirs = [extensions_dir]
100
+
101
+ ext_modules = [
102
+ extension(
103
+ "groundingdino._C",
104
+ sources,
105
+ include_dirs=include_dirs,
106
+ define_macros=define_macros,
107
+ extra_compile_args=extra_compile_args,
108
+ )
109
+ ]
110
+
111
+ return ext_modules
112
+
113
+
114
+ def parse_requirements(fname="requirements.txt", with_version=True):
115
+ """Parse the package dependencies listed in a requirements file but strips
116
+ specific versioning information.
117
+
118
+ Args:
119
+ fname (str): path to requirements file
120
+ with_version (bool, default=False): if True include version specs
121
+
122
+ Returns:
123
+ List[str]: list of requirements items
124
+
125
+ CommandLine:
126
+ python -c "import setup; print(setup.parse_requirements())"
127
+ """
128
+ import re
129
+ import sys
130
+ from os.path import exists
131
+
132
+ require_fpath = fname
133
+
134
+ def parse_line(line):
135
+ """Parse information from a line in a requirements text file."""
136
+ if line.startswith("-r "):
137
+ # Allow specifying requirements in other files
138
+ target = line.split(" ")[1]
139
+ for info in parse_require_file(target):
140
+ yield info
141
+ else:
142
+ info = {"line": line}
143
+ if line.startswith("-e "):
144
+ info["package"] = line.split("#egg=")[1]
145
+ elif "@git+" in line:
146
+ info["package"] = line
147
+ else:
148
+ # Remove versioning from the package
149
+ pat = "(" + "|".join([">=", "==", ">"]) + ")"
150
+ parts = re.split(pat, line, maxsplit=1)
151
+ parts = [p.strip() for p in parts]
152
+
153
+ info["package"] = parts[0]
154
+ if len(parts) > 1:
155
+ op, rest = parts[1:]
156
+ if ";" in rest:
157
+ # Handle platform specific dependencies
158
+ # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
159
+ version, platform_deps = map(str.strip, rest.split(";"))
160
+ info["platform_deps"] = platform_deps
161
+ else:
162
+ version = rest # NOQA
163
+ info["version"] = (op, version)
164
+ yield info
165
+
166
+ def parse_require_file(fpath):
167
+ with open(fpath, "r") as f:
168
+ for line in f.readlines():
169
+ line = line.strip()
170
+ if line and not line.startswith("#"):
171
+ for info in parse_line(line):
172
+ yield info
173
+
174
+ def gen_packages_items():
175
+ if exists(require_fpath):
176
+ for info in parse_require_file(require_fpath):
177
+ parts = [info["package"]]
178
+ if with_version and "version" in info:
179
+ parts.extend(info["version"])
180
+ if not sys.version.startswith("3.4"):
181
+ # apparently package_deps are broken in 3.4
182
+ platform_deps = info.get("platform_deps")
183
+ if platform_deps is not None:
184
+ parts.append(";" + platform_deps)
185
+ item = "".join(parts)
186
+ yield item
187
+
188
+ packages = list(gen_packages_items())
189
+ return packages
190
+
191
+
192
+ if __name__ == "__main__":
193
+ print(f"Building wheel {package_name}-{version}")
194
+
195
+ with open("LICENSE", "r", encoding="utf-8") as f:
196
+ license = f.read()
197
+
198
+ write_version_file()
199
+
200
+ setup(
201
+ name="groundingdino",
202
+ version="0.1.0",
203
+ author="International Digital Economy Academy, Shilong Liu",
204
+ url="https://github.com/IDEA-Research/GroundingDINO",
205
+ description="open-set object detector",
206
+ license=license,
207
+ install_requires=parse_requirements("requirements.txt"),
208
+ packages=find_packages(
209
+ exclude=(
210
+ "configs",
211
+ "tests",
212
+ )
213
+ ),
214
+ ext_modules=get_extensions(),
215
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
216
+ )
README.md CHANGED
@@ -1,188 +1,21 @@
1
  ---
2
- license: mit
3
- title: RAM_Tag2Text
4
- sdk: gradio
5
  emoji: 🐠
6
- colorFrom: blue
7
- colorTo: yellow
 
 
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
 
12
 
13
- # :label: Recognize Anything: A Strong Image Tagging Model & Tag2Text: Guiding Vision-Language Model via Image Tagging
14
-
15
- Official PyTorch Implementation of the <a href="https://recognize-anything.github.io/">Recognize Anything Model (RAM)</a> and the <a href="https://tag2text.github.io/">Tag2Text Model</a>.
16
-
17
- - RAM is a strong image tagging model, which can recognize any common category with high accuracy.
18
- - Tag2Text is an efficient and controllable vision-language model with tagging guidance.
19
-
20
-
21
- When combined with localization models ([Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything)), Tag2Text and RAM form a strong and general pipeline for visual semantic analysis.
22
-
23
- ![](./images/ram_grounded_sam.jpg)
24
-
25
- ## :sun_with_face: Helpful Tutorial
26
-
27
-
28
- - :apple: [[Access RAM Homepage](https://recognize-anything.github.io/)]
29
- - :grapes: [[Access Tag2Text Homepage](https://tag2text.github.io/)]
30
- - :sunflower: [[Read RAM arXiv Paper](https://arxiv.org/abs/2306.03514)]
31
- - :rose: [[Read Tag2Text arXiv Paper](https://arxiv.org/abs/2303.05657)]
32
- - :mushroom: [[Try our Tag2Text web Demo! 🤗](https://huggingface.co/spaces/xinyu1205/Tag2Text)]
33
-
34
-
35
-
36
- ## :bulb: Highlight
37
- **Recognition and localization are two foundation computer vision tasks.**
38
- - **The Segment Anything Model (SAM)** excels in **localization capabilities**, while it falls short when it comes to **recognition tasks**.
39
- - **The Recognize Anything Model (RAM) and Tag2Text** exhibits **exceptional recognition abilities**, in terms of **both accuracy and scope**.
40
-
41
- <p align="center">
42
- <table class="tg">
43
- <tr>
44
- <td class="tg-c3ow"><img src="images/localization_and_recognition.jpg" align="center" width="800" ></td>
45
- </tr>
46
- </table>
47
- </p>
48
-
49
-
50
- <details close>
51
- <summary><font size="4">
52
- Tag2Text for Vision-Language Tasks.
53
- </font></summary>
54
-
55
- - **Tagging.** Without manual annotations, Tag2Text achieves **superior** image tag recognition ability of [**3,429**](./data/tag_list.txt) commonly human-used categories.
56
- - **Efficient.** Tagging guidance effectively enhances the performance of vision-language models on both **generation-based** and **alignment-based** tasks.
57
- - **Controllable.** Tag2Text permits users to input **desired tags**, providing the flexibility in composing corresponding texts based on the input tags.
58
-
59
-
60
- <p align="center">
61
- <table class="tg">
62
- <tr>
63
- <td class="tg-c3ow"><img src="images/tag2text_framework.png" align="center" width="800" ></td>
64
- </tr>
65
- </table>
66
- </p>
67
- </details>
68
-
69
-
70
- <details close>
71
- <summary><font size="4">
72
- Advancements of RAM on Tag2Text.
73
- </font></summary>
74
-
75
- - **Accuracy.** RAM utilizes a data engine to generate additional annotations and clean incorrect ones, resulting higher accuracy compared to Tag2Text.
76
- - **Scope.** Tag2Text recognizes 3,400+ fixed tags. RAM upgrades the number to 6,400+, covering more valuable categories. With open-set capability, RAM is feasible to recognize any common category.
77
-
78
-
79
- </details>
80
-
81
-
82
- ## :sparkles: Highlight Projects with other Models
83
- - [Tag2Text/RAM with Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything) is trong and general pipeline for visual semantic analysis, which can automatically **recognize**, detect, and segment for an image!
84
- - [Ask-Anything](https://github.com/OpenGVLab/Ask-Anything) is a multifunctional video question answering tool. Tag2Text provides powerful tagging and captioning capabilities as a fundamental component.
85
- - [Prompt-can-anything](https://github.com/positive666/Prompt-Can-Anything) is a gradio web library that integrates SOTA multimodal large models, including Tag2text as the core model for graphic understanding
86
-
87
-
88
-
89
-
90
- ## :fire: News
91
-
92
- - **`2023/06/07`**: We release the [Recognize Anything Model (RAM)](https://recognize-anything.github.io/), a strong image tagging model!
93
- - **`2023/06/05`**: Tag2Text is combined with [Prompt-can-anything](https://github.com/OpenGVLab/Ask-Anything).
94
- - **`2023/05/20`**: Tag2Text is combined with [VideoChat](https://github.com/OpenGVLab/Ask-Anything).
95
- - **`2023/04/20`**: We marry Tag2Text with with [Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything).
96
- - **`2023/04/10`**: Code and checkpoint is available Now!
97
- - **`2023/03/14`**: [Tag2Text web demo 🤗](https://huggingface.co/spaces/xinyu1205/Tag2Text) is available on Hugging Face Space!
98
-
99
-
100
-
101
-
102
-
103
-
104
- ## :writing_hand: TODO
105
-
106
- - [x] Release Tag2Text demo.
107
- - [x] Release checkpoints.
108
- - [x] Release inference code.
109
- - [ ] Release RAM demo and checkpoints (until June 14th at the latest).
110
- - [ ] Release training codes (until August 1st at the latest).
111
- - [ ] Release training datasets (until August 1st at the latest).
112
-
113
-
114
-
115
- ## :toolbox: Checkpoints
116
-
117
- <!-- insert a table -->
118
- <table>
119
- <thead>
120
- <tr style="text-align: right;">
121
- <th></th>
122
- <th>name</th>
123
- <th>backbone</th>
124
- <th>Data</th>
125
- <th>Illustration</th>
126
- <th>Checkpoint</th>
127
- </tr>
128
- </thead>
129
- <tbody>
130
- <tr>
131
- <th>1</th>
132
- <td>Tag2Text-Swin</td>
133
- <td>Swin-Base</td>
134
- <td>COCO, VG, SBU, CC-3M, CC-12M</td>
135
- <td>Demo version with comprehensive captions.</td>
136
- <td><a href="https://huggingface.co/spaces/xinyu1205/Tag2Text/blob/main/tag2text_swin_14m.pth">Download link</a></td>
137
- </tr>
138
- </tbody>
139
- </table>
140
-
141
-
142
- ## :running: Tag2Text Inference
143
-
144
- 1. Install the dependencies, run:
145
-
146
- <pre/>pip install -r requirements.txt</pre>
147
-
148
- 2. Download Tag2Text pretrained checkpoints.
149
-
150
- 3. Get the tagging and captioning results:
151
- <pre/>
152
- python inference.py --image images/1641173_2291260800.jpg \
153
- --pretrained pretrained/tag2text_swin_14m.pth
154
- </pre>
155
- Or get the tagging and sepcifed captioning results (optional):
156
- <pre/>python inference.py --image images/1641173_2291260800.jpg \
157
- --pretrained pretrained/tag2text_swin_14m.pth \
158
- --specified-tags "cloud,sky"</pre>
159
-
160
-
161
- ## :black_nib: Citation
162
- If you find our work to be useful for your research, please consider citing.
163
-
164
- ```
165
- @misc{zhang2023recognize,
166
- title={Recognize Anything: A Strong Image Tagging Model},
167
- author={Youcai Zhang and Xinyu Huang and Jinyu Ma and Zhaoyang Li and Zhaochuan Luo and Yanchun Xie and Yuzhuo Qin and Tong Luo and Yaqian Li and Shilong Liu and Yandong Guo and Lei Zhang},
168
- year={2023},
169
- eprint={2306.03514},
170
- archivePrefix={arXiv},
171
- primaryClass={cs.CV}
172
- }
173
-
174
- @article{huang2023tag2text,
175
- title={Tag2Text: Guiding Vision-Language Model via Image Tagging},
176
- author={Huang, Xinyu and Zhang, Youcai and Ma, Jinyu and Tian, Weiwei and Feng, Rui and Zhang, Yuejie and Li, Yaqian and Guo, Yandong and Zhang, Lei},
177
- journal={arXiv preprint arXiv:2303.05657},
178
- year={2023}
179
- }
180
- ```
181
-
182
- ## :hearts: Acknowledgements
183
- This work is done with the help of the amazing code base of [BLIP](https://github.com/salesforce/BLIP), thanks very much!
184
-
185
- We also want to thank @Cheng Rui @Shilong Liu @Ren Tianhe for their help in [marrying Tag2Text with Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything).
186
 
 
187
 
 
188
 
 
 
1
  ---
2
+ title: Recognize Detect Segment Anything
 
 
3
  emoji: 🐠
4
+ colorFrom: yellow
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 3.35.2
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
+ # Recognize Anything Model + Grounded-SAM DEMO
14
 
15
+ Please refer to these homepages for details of these projects:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ - Recognize Anthing Model: [https://recognize-anything.github.io/](https://recognize-anything.github.io/)
18
 
19
+ - Tag2Text Model: [https://https://tag2text.github.io/](https://https://tag2text.github.io/)
20
 
21
+ - Grounded-SAM: [https://github.com/IDEA-Research/Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything)
app.py CHANGED
@@ -1,209 +1,387 @@
1
- import numpy as np
2
- import random
3
-
4
- import torch
5
- import torchvision.transforms as transforms
6
-
7
- from PIL import Image
8
- from models.tag2text import tag2text_caption, ram
9
-
10
- import gradio as gr
11
-
12
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
-
14
- image_size = 384
15
-
16
- normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
17
- std=[0.229, 0.224, 0.225
18
- ])
19
- transform = transforms.Compose([transforms.Resize((image_size, image_size)),transforms.ToTensor(),normalize])
20
-
21
- #######Tag2Text Model
22
- pretrained = 'tag2text_swin_14m.pth'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- model_tag2text = tag2text_caption(pretrained=pretrained, image_size=image_size, vit='swin_b' )
25
 
26
- model_tag2text.eval()
27
- model_tag2text = model_tag2text.to(device)
28
 
29
 
30
- #######RAM Model
31
- pretrained = 'ram_swin_large_14m.pth'
 
 
 
32
 
33
- model_ram = ram(pretrained=pretrained, image_size=image_size, vit='swin_l' )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- model_ram.eval()
36
- model_ram = model_ram.to(device)
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- def inference(raw_image, model_n , input_tag):
40
- raw_image = raw_image.resize((image_size, image_size))
 
 
 
41
 
42
- image = transform(raw_image).unsqueeze(0).to(device)
43
- if model_n == 'Recognize Anything Model':
44
- model = model_ram
45
- with torch.no_grad():
46
- tags, tags_chinese = model.generate_tag(image)
47
- return tags[0],tags_chinese[0], 'none'
48
- else:
49
- model = model_tag2text
50
- model.threshold = 0.68
51
- if input_tag == '' or input_tag == 'none' or input_tag == 'None':
52
- input_tag_list = None
53
- else:
54
- input_tag_list = []
55
- input_tag_list.append(input_tag.replace(',',' | '))
56
- with torch.no_grad():
57
-
58
-
59
- caption, tag_predict = model.generate(image,tag_input = input_tag_list,max_length = 50, return_tag_predict = True)
60
- if input_tag_list == None:
61
- tag_1 = tag_predict
62
- tag_2 = ['none']
63
- else:
64
- _, tag_1 = model.generate(image,tag_input = None, max_length = 50, return_tag_predict = True)
65
- tag_2 = tag_predict
66
-
67
- return tag_1[0],'none',caption[0]
68
-
69
-
70
- def build_gui():
71
-
72
- description = """
73
- <center><strong><font size='10'>Recognize Anything Model</font></strong></center>
74
- <br>
75
- Welcome to the Recognize Anything Model (RAM) and Tag2Text Model demo! <br><br>
76
- <li>
77
- <b>Recognize Anything Model:</b> Upload your image to get the <b>English and Chinese outputs of the image tags</b>!
78
- </li>
79
- <li>
80
- <b>Tag2Text Model:</b> Upload your image to get the <b>tags</b> and <b>caption</b> of the image.
81
- Optional: You can also input specified tags to get the corresponding caption.
82
- </li>
83
- """ # noqa
84
-
85
- article = """
86
- <p style='text-align: center'>
87
- RAM and Tag2Text is training on open-source datasets, and we are persisting in refining and iterating upon it.<br/>
88
- <a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything: A Strong Image Tagging Model</a>
89
- |
90
- <a href='https://https://tag2text.github.io/' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a>
91
- |
92
- <a href='https://github.com/xinyu1205/Tag2Text' target='_blank'>Github Repo</a>
93
- </p>
94
- """ # noqa
95
-
96
- def inference_with_ram(img):
97
- res = inference(img, "Recognize Anything Model", None)
98
- return res[0], res[1]
99
-
100
- def inference_with_t2t(img, input_tags):
101
- res = inference(img, "Tag2Text Model", input_tags)
102
- return res[0], res[2]
103
-
104
- with gr.Blocks(title="Recognize Anything Model") as demo:
105
- ###############
106
- # components
107
- ###############
108
- gr.HTML(description)
109
-
110
- with gr.Tab(label="Recognize Anything Model"):
111
- with gr.Row():
112
- with gr.Column():
113
- ram_in_img = gr.Image(type="pil")
114
- with gr.Row():
115
- ram_btn_run = gr.Button(value="Run")
116
- ram_btn_clear = gr.Button(value="Clear")
117
- with gr.Column():
118
- ram_out_tag = gr.Textbox(label="Tags")
119
- ram_out_biaoqian = gr.Textbox(label="标签")
120
- gr.Examples(
121
- examples=[
122
- ["images/demo1.jpg"],
123
- ["images/demo2.jpg"],
124
- ["images/demo4.jpg"],
125
- ],
126
  fn=inference_with_ram,
127
- inputs=[ram_in_img],
128
- outputs=[ram_out_tag, ram_out_biaoqian],
129
- cache_examples=True
130
  )
131
-
132
- with gr.Tab(label="Tag2Text Model"):
133
- with gr.Row():
134
- with gr.Column():
135
- t2t_in_img = gr.Image(type="pil")
136
- t2t_in_tag = gr.Textbox(label="User Specified Tags (Optional, separated by comma)")
137
- with gr.Row():
138
- t2t_btn_run = gr.Button(value="Run")
139
- t2t_btn_clear = gr.Button(value="Clear")
140
- with gr.Column():
141
- t2t_out_tag = gr.Textbox(label="Tags")
142
- t2t_out_cap = gr.Textbox(label="Caption")
143
- gr.Examples(
144
- examples=[
145
- ["images/demo4.jpg", ""],
146
- ["images/demo4.jpg", "power line"],
147
- ["images/demo4.jpg", "track, train"],
148
- ],
149
  fn=inference_with_t2t,
150
- inputs=[t2t_in_img, t2t_in_tag],
151
- outputs=[t2t_out_tag, t2t_out_cap],
152
- cache_examples=True
153
  )
154
 
155
- gr.HTML(article)
156
-
157
- ###############
158
- # events
159
- ###############
160
- # run inference
161
- ram_btn_run.click(
162
- fn=inference_with_ram,
163
- inputs=[ram_in_img],
164
- outputs=[ram_out_tag, ram_out_biaoqian]
165
- )
166
- t2t_btn_run.click(
167
- fn=inference_with_t2t,
168
- inputs=[t2t_in_img, t2t_in_tag],
169
- outputs=[t2t_out_tag, t2t_out_cap]
170
- )
171
-
172
- # # images of two image panels should keep the same
173
- # # and clear old outputs when image changes
174
- # # slow due to internet latency when deployed on huggingface, comment out
175
- # def sync_img(v):
176
- # return [gr.update(value=v)] + [gr.update(value="")] * 4
177
-
178
- # ram_in_img.upload(fn=sync_img, inputs=[ram_in_img], outputs=[
179
- # t2t_in_img, ram_out_tag, ram_out_biaoqian, t2t_out_tag, t2t_out_cap
180
- # ])
181
- # ram_in_img.clear(fn=sync_img, inputs=[ram_in_img], outputs=[
182
- # t2t_in_img, ram_out_tag, ram_out_biaoqian, t2t_out_tag, t2t_out_cap
183
- # ])
184
- # t2t_in_img.clear(fn=sync_img, inputs=[t2t_in_img], outputs=[
185
- # ram_in_img, ram_out_tag, ram_out_biaoqian, t2t_out_tag, t2t_out_cap
186
- # ])
187
- # t2t_in_img.upload(fn=sync_img, inputs=[t2t_in_img], outputs=[
188
- # ram_in_img, ram_out_tag, ram_out_biaoqian, t2t_out_tag, t2t_out_cap
189
- # ])
190
-
191
- # clear all
192
- def clear_all():
193
- return [gr.update(value=None)] * 2 + [gr.update(value="")] * 5
194
-
195
- ram_btn_clear.click(fn=clear_all, inputs=[], outputs=[
196
- ram_in_img, t2t_in_img,
197
- ram_out_tag, ram_out_biaoqian, t2t_in_tag, t2t_out_tag, t2t_out_cap
198
- ])
199
- t2t_btn_clear.click(fn=clear_all, inputs=[], outputs=[
200
- ram_in_img, t2t_in_img,
201
- ram_out_tag, ram_out_biaoqian, t2t_in_tag, t2t_out_tag, t2t_out_cap
202
- ])
203
-
204
- return demo
205
 
 
 
 
206
 
207
- if __name__ == "__main__":
208
- demo = build_gui()
209
- demo.launch(enable_queue=True)
 
1
+ import os
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ # setup Grouded-Segment-Anything
6
+ # building GroundingDINO requires torch but imports it before installing,
7
+ # so directly installing in requirements.txt causes dependency error.
8
+ # 1. build with "-e" option to keep the bin file in ./GroundingDINO/groundingdino/, rather than in site-package dir.
9
+ os.system("pip install -e ./GroundingDINO/")
10
+ # 2. for unknown reason, "import groundingdino" will fill due to unable to find the module, even after installing.
11
+ # add ./GroundingDINO/ to PATH, so package "groundingdino" can be imported.
12
+ sys.path.append(str(Path(__file__).parent / "GroundingDINO"))
13
+
14
+ import random # noqa: E402
15
+
16
+ import cv2 # noqa: E402
17
+ import groundingdino.datasets.transforms as T # noqa: E402
18
+ import numpy as np # noqa: E402
19
+ import torch # noqa: E402
20
+ import torchvision # noqa: E402
21
+ import torchvision.transforms as TS # noqa: E402
22
+ from groundingdino.models import build_model # noqa: E402
23
+ from groundingdino.util.slconfig import SLConfig # noqa: E402
24
+ from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap # noqa: E402
25
+ from PIL import Image, ImageDraw, ImageFont # noqa: E402
26
+ from ram import inference_ram # noqa: E402
27
+ from ram import inference_tag2text # noqa: E402
28
+ from ram.models import ram # noqa: E402
29
+ from ram.models import tag2text_caption # noqa: E402
30
+ from segment_anything import SamPredictor, build_sam # noqa: E402
31
+
32
+
33
+ # args
34
+ config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
35
+ ram_checkpoint = "./ram_swin_large_14m.pth"
36
+ tag2text_checkpoint = "./tag2text_swin_14m.pth"
37
+ grounded_checkpoint = "./groundingdino_swint_ogc.pth"
38
+ sam_checkpoint = "./sam_vit_h_4b8939.pth"
39
+ box_threshold = 0.25
40
+ text_threshold = 0.2
41
+ iou_threshold = 0.5
42
+ device = "cpu"
43
+
44
+
45
+ def load_model(model_config_path, model_checkpoint_path, device):
46
+ args = SLConfig.fromfile(model_config_path)
47
+ args.device = device
48
+ model = build_model(args)
49
+ checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
50
+ load_res = model.load_state_dict(
51
+ clean_state_dict(checkpoint["model"]), strict=False)
52
+ print(load_res)
53
+ _ = model.eval()
54
+ return model
55
+
56
+
57
+ def get_grounding_output(model, image, caption, box_threshold, text_threshold, device="cpu"):
58
+ caption = caption.lower()
59
+ caption = caption.strip()
60
+ if not caption.endswith("."):
61
+ caption = caption + "."
62
+ model = model.to(device)
63
+ image = image.to(device)
64
+ with torch.no_grad():
65
+ outputs = model(image[None], captions=[caption])
66
+ logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
67
+ boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
68
+ logits.shape[0]
69
+
70
+ # filter output
71
+ logits_filt = logits.clone()
72
+ boxes_filt = boxes.clone()
73
+ filt_mask = logits_filt.max(dim=1)[0] > box_threshold
74
+ logits_filt = logits_filt[filt_mask] # num_filt, 256
75
+ boxes_filt = boxes_filt[filt_mask] # num_filt, 4
76
+ logits_filt.shape[0]
77
+
78
+ # get phrase
79
+ tokenlizer = model.tokenizer
80
+ tokenized = tokenlizer(caption)
81
+ # build pred
82
+ pred_phrases = []
83
+ scores = []
84
+ for logit, box in zip(logits_filt, boxes_filt):
85
+ pred_phrase = get_phrases_from_posmap(
86
+ logit > text_threshold, tokenized, tokenlizer)
87
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
88
+ scores.append(logit.max().item())
89
+
90
+ return boxes_filt, torch.Tensor(scores), pred_phrases
91
+
92
+
93
+ def draw_mask(mask, draw, random_color=False):
94
+ if random_color:
95
+ color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 153)
96
+ else:
97
+ color = (30, 144, 255, 153)
98
 
99
+ nonzero_coords = np.transpose(np.nonzero(mask))
100
 
101
+ for coord in nonzero_coords:
102
+ draw.point(coord[::-1], fill=color)
103
 
104
 
105
+ def draw_box(box, draw, label):
106
+ # random color
107
+ color = tuple(np.random.randint(0, 255, size=3).tolist())
108
+ line_width = int(max(4, min(20, 0.006*max(draw.im.size))))
109
+ draw.rectangle(((box[0], box[1]), (box[2], box[3])), outline=color, width=line_width)
110
 
111
+ if label:
112
+ font_path = os.path.join(
113
+ cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
114
+ font_size = int(max(12, min(60, 0.02*max(draw.im.size))))
115
+ font = ImageFont.truetype(font_path, size=font_size)
116
+ if hasattr(font, "getbbox"):
117
+ bbox = draw.textbbox((box[0], box[1]), str(label), font)
118
+ else:
119
+ w, h = draw.textsize(str(label), font)
120
+ bbox = (box[0], box[1], w + box[0], box[1] + h)
121
+ draw.rectangle(bbox, fill=color)
122
+ draw.text((box[0], box[1]), str(label), fill="white", font=font)
123
+
124
+ draw.text((box[0], box[1]), label, font=font)
125
+
126
+
127
+ @torch.no_grad()
128
+ def inference(
129
+ raw_image, specified_tags, do_det_seg,
130
+ tagging_model_type, tagging_model, grounding_dino_model, sam_model
131
+ ):
132
+ print(f"Start processing, image size {raw_image.size}")
133
+ raw_image = raw_image.convert("RGB")
134
+
135
+ # run tagging model
136
+ normalize = TS.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
137
+ transform = TS.Compose([
138
+ TS.Resize((384, 384)),
139
+ TS.ToTensor(),
140
+ normalize
141
+ ])
142
+
143
+ image = raw_image.resize((384, 384))
144
+ image = transform(image).unsqueeze(0).to(device)
145
+
146
+ # Currently ", " is better for detecting single tags
147
+ # while ". " is a little worse in some case
148
+ if tagging_model_type == "RAM":
149
+ res = inference_ram(image, tagging_model)
150
+ tags = res[0].strip(' ').replace(' ', ' ').replace(' |', ',')
151
+ tags_chinese = res[1].strip(' ').replace(' ', ' ').replace(' |', ',')
152
+ print("Tags: ", tags)
153
+ print("图像标签: ", tags_chinese)
154
+ else:
155
+ res = inference_tag2text(image, tagging_model, specified_tags)
156
+ tags = res[0].strip(' ').replace(' ', ' ').replace(' |', ',')
157
+ caption = res[2]
158
+ print(f"Tags: {tags}")
159
+ print(f"Caption: {caption}")
160
+
161
+ # return
162
+ if not do_det_seg:
163
+ if tagging_model_type == "RAM":
164
+ return tags.replace(", ", " | "), tags_chinese.replace(", ", " | "), None
165
+ else:
166
+ return tags.replace(", ", " | "), caption, None
167
+
168
+ # run groundingDINO
169
+ transform = T.Compose([
170
+ T.RandomResize([800], max_size=1333),
171
+ T.ToTensor(),
172
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
173
+ ])
174
+
175
+ image, _ = transform(raw_image, None) # 3, h, w
176
+
177
+ boxes_filt, scores, pred_phrases = get_grounding_output(
178
+ grounding_dino_model, image, tags, box_threshold, text_threshold, device=device
179
+ )
180
+ print("GroundingDINO finished")
181
+
182
+ # run SAM
183
+ image = np.asarray(raw_image)
184
+ sam_model.set_image(image)
185
+
186
+ size = raw_image.size
187
+ H, W = size[1], size[0]
188
+ for i in range(boxes_filt.size(0)):
189
+ boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
190
+ boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
191
+ boxes_filt[i][2:] += boxes_filt[i][:2]
192
+
193
+ boxes_filt = boxes_filt.cpu()
194
+ # use NMS to handle overlapped boxes
195
+ print(f"Before NMS: {boxes_filt.shape[0]} boxes")
196
+ nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
197
+ boxes_filt = boxes_filt[nms_idx]
198
+ pred_phrases = [pred_phrases[idx] for idx in nms_idx]
199
+ print(f"After NMS: {boxes_filt.shape[0]} boxes")
200
+
201
+ transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
202
+
203
+ masks, _, _ = sam_model.predict_torch(
204
+ point_coords=None,
205
+ point_labels=None,
206
+ boxes=transformed_boxes.to(device),
207
+ multimask_output=False,
208
+ )
209
+ print("SAM finished")
210
+
211
+ # draw output image
212
+ mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
213
+
214
+ mask_draw = ImageDraw.Draw(mask_image)
215
+ for mask in masks:
216
+ draw_mask(mask[0].cpu().numpy(), mask_draw, random_color=True)
217
+
218
+ image_draw = ImageDraw.Draw(raw_image)
219
+
220
+ for box, label in zip(boxes_filt, pred_phrases):
221
+ draw_box(box, image_draw, label)
222
+
223
+ out_image = raw_image.convert('RGBA')
224
+ out_image.alpha_composite(mask_image)
225
+
226
+ # return
227
+ if tagging_model_type == "RAM":
228
+ return tags.replace(", ", " | "), tags_chinese.replace(", ", " | "), out_image
229
+ else:
230
+ return tags.replace(", ", " | "), caption, out_image
231
 
 
 
232
 
233
+ if __name__ == "__main__":
234
+ import gradio as gr
235
+
236
+ # load RAM
237
+ ram_model = ram(pretrained=ram_checkpoint, image_size=384, vit='swin_l')
238
+ ram_model.eval()
239
+ ram_model = ram_model.to(device)
240
+
241
+ # load Tag2Text
242
+ delete_tag_index = [] # filter out attributes and action categories which are difficult to grounding
243
+ for i in range(3012, 3429):
244
+ delete_tag_index.append(i)
245
+
246
+ tag2text_model = tag2text_caption(pretrained=tag2text_checkpoint,
247
+ image_size=384,
248
+ vit='swin_b',
249
+ delete_tag_index=delete_tag_index)
250
+ tag2text_model.threshold = 0.64 # we reduce the threshold to obtain more tags
251
+ tag2text_model.eval()
252
+ tag2text_model = tag2text_model.to(device)
253
+
254
+ # load groundingDINO
255
+ grounding_dino_model = load_model(config_file, grounded_checkpoint, device=device)
256
+
257
+ # load SAM
258
+ sam_model = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
259
+
260
+ # build GUI
261
+ def build_gui():
262
+
263
+ description = """
264
+ <center><strong><font size='10'>Recognize Anything Model + Grounded-SAM</font></strong></center>
265
+ <br>
266
+ Welcome to the RAM/Tag2Text + Grounded-SAM demo! <br><br>
267
+ <li>
268
+ <b>Recognize Anything Model:</b> Upload your image to get the <b>English and Chinese tags</b>!
269
+ </li>
270
+ <li>
271
+ <b>Tag2Text Model:</b> Upload your image to get the <b>tags and caption</b>!
272
+ (Optional: Specify tags to get the corresponding caption.)
273
+ </li>
274
+ <li>
275
+ <b>Grounded-SAM:</b> Tick the checkbox to get <b>boxes</b> and <b>masks</b> of tags!
276
+ </li>
277
+ <br>
278
+ Great thanks to <a href='https://huggingface.co/majinyu' target='_blank'>Ma Jinyu</a>, the major contributor of this demo!
279
+ """ # noqa
280
+
281
+ article = """
282
+ <p style='text-align: center'>
283
+ RAM and Tag2Text are trained on open-source datasets, and we are persisting in refining and iterating upon it.<br/>
284
+ Grounded-SAM is a combination of Grounding DINO and SAM aming to detect and segment anything with text inputs.<br/>
285
+ <a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything: A Strong Image Tagging Model</a>
286
+ |
287
+ <a href='https://https://tag2text.github.io/' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a>
288
+ |
289
+ <a href='https://github.com/IDEA-Research/Grounded-Segment-Anything' target='_blank'>Grounded-Segment-Anything</a>
290
+ </p>
291
+ """ # noqa
292
+
293
+ def inference_with_ram(img, do_det_seg):
294
+ return inference(
295
+ img, None, do_det_seg,
296
+ "RAM", ram_model, grounding_dino_model, sam_model
297
+ )
298
 
299
+ def inference_with_t2t(img, input_tags, do_det_seg):
300
+ return inference(
301
+ img, input_tags, do_det_seg,
302
+ "Tag2Text", tag2text_model, grounding_dino_model, sam_model
303
+ )
304
 
305
+ with gr.Blocks(title="Recognize Anything Model") as demo:
306
+ ###############
307
+ # components
308
+ ###############
309
+ gr.HTML(description)
310
+
311
+ with gr.Tab(label="Recognize Anything Model"):
312
+ with gr.Row():
313
+ with gr.Column():
314
+ ram_in_img = gr.Image(type="pil")
315
+ ram_opt_det_seg = gr.Checkbox(label="Get Boxes and Masks with Grounded-SAM", value=True)
316
+ with gr.Row():
317
+ ram_btn_run = gr.Button(value="Run")
318
+ ram_btn_clear = gr.ClearButton()
319
+ with gr.Column():
320
+ ram_out_img = gr.Image(type="pil")
321
+ ram_out_tag = gr.Textbox(label="Tags")
322
+ ram_out_biaoqian = gr.Textbox(label="标签")
323
+ gr.Examples(
324
+ examples=[
325
+ ["images/demo1.jpg", True],
326
+ ["images/demo2.jpg", True],
327
+ ["images/demo4.jpg", True],
328
+ ],
329
+ fn=inference_with_ram,
330
+ inputs=[ram_in_img, ram_opt_det_seg],
331
+ outputs=[ram_out_tag, ram_out_biaoqian, ram_out_img],
332
+ cache_examples=True
333
+ )
334
+
335
+ with gr.Tab(label="Tag2Text Model"):
336
+ with gr.Row():
337
+ with gr.Column():
338
+ t2t_in_img = gr.Image(type="pil")
339
+ t2t_in_tag = gr.Textbox(label="User Specified Tags (Optional, separated by comma)")
340
+ t2t_opt_det_seg = gr.Checkbox(label="Get Boxes and Masks with Grounded-SAM", value=True)
341
+ with gr.Row():
342
+ t2t_btn_run = gr.Button(value="Run")
343
+ t2t_btn_clear = gr.ClearButton()
344
+ with gr.Column():
345
+ t2t_out_img = gr.Image(type="pil")
346
+ t2t_out_tag = gr.Textbox(label="Tags")
347
+ t2t_out_cap = gr.Textbox(label="Caption")
348
+ gr.Examples(
349
+ examples=[
350
+ ["images/demo4.jpg", "", True],
351
+ ["images/demo4.jpg", "power line", False],
352
+ ["images/demo4.jpg", "track, train", False],
353
+ ],
354
+ fn=inference_with_t2t,
355
+ inputs=[t2t_in_img, t2t_in_tag, t2t_opt_det_seg],
356
+ outputs=[t2t_out_tag, t2t_out_cap, t2t_out_img],
357
+ cache_examples=True
358
+ )
359
+
360
+ gr.HTML(article)
361
+
362
+ ###############
363
+ # events
364
+ ###############
365
+ # run inference
366
+ ram_btn_run.click(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  fn=inference_with_ram,
368
+ inputs=[ram_in_img, ram_opt_det_seg],
369
+ outputs=[ram_out_tag, ram_out_biaoqian, ram_out_img]
 
370
  )
371
+ t2t_btn_run.click(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  fn=inference_with_t2t,
373
+ inputs=[t2t_in_img, t2t_in_tag, t2t_opt_det_seg],
374
+ outputs=[t2t_out_tag, t2t_out_cap, t2t_out_img]
 
375
  )
376
 
377
+ # hide or show image output
378
+ ram_opt_det_seg.change(fn=lambda b: gr.update(visible=b), inputs=[ram_opt_det_seg], outputs=[ram_out_img])
379
+ t2t_opt_det_seg.change(fn=lambda b: gr.update(visible=b), inputs=[t2t_opt_det_seg], outputs=[t2t_out_img])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
 
381
+ # clear
382
+ ram_btn_clear.add([ram_in_img, ram_out_img, ram_out_tag, ram_out_biaoqian])
383
+ t2t_btn_clear.add([t2t_in_img, t2t_in_tag, t2t_out_img, t2t_out_tag, t2t_out_cap])
384
 
385
+ return demo
386
+
387
+ build_gui().launch(enable_queue=True)
configs/med_config.json DELETED
@@ -1,21 +0,0 @@
1
- {
2
- "architectures": [
3
- "BertModel"
4
- ],
5
- "attention_probs_dropout_prob": 0.1,
6
- "hidden_act": "gelu",
7
- "hidden_dropout_prob": 0.1,
8
- "hidden_size": 768,
9
- "initializer_range": 0.02,
10
- "intermediate_size": 3072,
11
- "layer_norm_eps": 1e-12,
12
- "max_position_embeddings": 512,
13
- "model_type": "bert",
14
- "num_attention_heads": 12,
15
- "num_hidden_layers": 12,
16
- "pad_token_id": 0,
17
- "type_vocab_size": 2,
18
- "vocab_size": 30524,
19
- "encoder_width": 768,
20
- "add_cross_attention": true
21
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/q2l_config.json DELETED
@@ -1,22 +0,0 @@
1
- {
2
- "architectures": [
3
- "BertModel"
4
- ],
5
- "attention_probs_dropout_prob": 0.1,
6
- "hidden_act": "gelu",
7
- "hidden_dropout_prob": 0.1,
8
- "hidden_size": 768,
9
- "initializer_range": 0.02,
10
- "intermediate_size": 3072,
11
- "layer_norm_eps": 1e-12,
12
- "max_position_embeddings": 512,
13
- "model_type": "bert",
14
- "num_attention_heads": 4,
15
- "num_hidden_layers": 2,
16
- "pad_token_id": 0,
17
- "type_vocab_size": 2,
18
- "vocab_size": 30522,
19
- "encoder_width": 768,
20
- "add_cross_attention": true,
21
- "add_tag_cross_attention": false
22
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/swin/config_swinB_224.json DELETED
@@ -1,10 +0,0 @@
1
- {
2
- "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
3
- "vision_width": 1024,
4
- "image_res": 224,
5
- "window_size": 7,
6
- "embed_dim": 128,
7
- "depths": [ 2, 2, 18, 2 ],
8
- "num_heads": [ 4, 8, 16, 32 ]
9
- }
10
-
 
 
 
 
 
 
 
 
 
 
 
configs/swin/config_swinB_384.json DELETED
@@ -1,10 +0,0 @@
1
- {
2
- "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
3
- "vision_width": 1024,
4
- "image_res": 384,
5
- "window_size": 12,
6
- "embed_dim": 128,
7
- "depths": [ 2, 2, 18, 2 ],
8
- "num_heads": [ 4, 8, 16, 32 ]
9
- }
10
-
 
 
 
 
 
 
 
 
 
 
 
configs/swin/config_swinB_480.json DELETED
@@ -1,9 +0,0 @@
1
- {
2
- "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
3
- "vision_width": 1024,
4
- "image_res": 480,
5
- "window_size": 15,
6
- "embed_dim": 128,
7
- "depths": [ 2, 2, 18, 2 ],
8
- "num_heads": [ 4, 8, 16, 32 ]
9
- }
 
 
 
 
 
 
 
 
 
 
configs/swin/config_swinB_576.json DELETED
@@ -1,9 +0,0 @@
1
- {
2
- "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
3
- "vision_width": 1024,
4
- "image_res": 576,
5
- "window_size": 18,
6
- "embed_dim": 128,
7
- "depths": [ 2, 2, 18, 2 ],
8
- "num_heads": [ 4, 8, 16, 32 ]
9
- }