returnzeros committed
Commit 4d8e7a6 · verified · Parent: 6927973

Upload 108 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +1 -0
  2. ACT_DP_multitask/README.md +16 -0
  3. ACT_DP_multitask/base.yaml +71 -0
  4. ACT_DP_multitask/detr/LICENSE +201 -0
  5. ACT_DP_multitask/detr/README.md +9 -0
  6. ACT_DP_multitask/detr/__pycache__/main.cpython-310.pyc +0 -0
  7. ACT_DP_multitask/detr/__pycache__/main.cpython-37.pyc +0 -0
  8. ACT_DP_multitask/detr/detr.egg-info/PKG-INFO +17 -0
  9. ACT_DP_multitask/detr/detr.egg-info/SOURCES.txt +37 -0
  10. ACT_DP_multitask/detr/detr.egg-info/dependency_links.txt +1 -0
  11. ACT_DP_multitask/detr/detr.egg-info/top_level.txt +2 -0
  12. ACT_DP_multitask/detr/main.py +763 -0
  13. ACT_DP_multitask/detr/models/__init__.py +60 -0
  14. ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-310.pyc +0 -0
  15. ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-37.pyc +0 -0
  16. ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-38.pyc +0 -0
  17. ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-310.pyc +0 -0
  18. ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-37.pyc +0 -0
  19. ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-38.pyc +0 -0
  20. ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-310.pyc +0 -0
  21. ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-37.pyc +0 -0
  22. ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-38.pyc +0 -0
  23. ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-310.pyc +0 -0
  24. ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-37.pyc +0 -0
  25. ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-38.pyc +0 -0
  26. ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-310.pyc +0 -0
  27. ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-37.pyc +0 -0
  28. ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-38.pyc +0 -0
  29. ACT_DP_multitask/detr/models/__pycache__/resnet_film.cpython-310.pyc +0 -0
  30. ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-310.pyc +0 -0
  31. ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-37.pyc +0 -0
  32. ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-38.pyc +0 -0
  33. ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-310.pyc +0 -0
  34. ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-37.pyc +0 -0
  35. ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-38.pyc +0 -0
  36. ACT_DP_multitask/detr/models/backbone.py +209 -0
  37. ACT_DP_multitask/detr/models/detr_vae.py +0 -0
  38. ACT_DP_multitask/detr/models/detr_vae_nfp.py +523 -0
  39. ACT_DP_multitask/detr/models/mask_former/__init__.py +19 -0
  40. ACT_DP_multitask/detr/models/mask_former/__pycache__/__init__.cpython-38.pyc +0 -0
  41. ACT_DP_multitask/detr/models/mask_former/config.py +85 -0
  42. ACT_DP_multitask/detr/models/mask_former/mask_former_model.py +304 -0
  43. ACT_DP_multitask/detr/models/mask_former/modeling/__init__.py +5 -0
  44. ACT_DP_multitask/detr/models/mask_former/modeling/backbone/__init__.py +1 -0
  45. ACT_DP_multitask/detr/models/mask_former/modeling/backbone/swin.py +768 -0
  46. ACT_DP_multitask/detr/models/mask_former/modeling/criterion.py +187 -0
  47. ACT_DP_multitask/detr/models/mask_former/modeling/heads/__init__.py +1 -0
  48. ACT_DP_multitask/detr/models/mask_former/modeling/heads/mask_former_head.py +119 -0
  49. ACT_DP_multitask/detr/models/mask_former/modeling/heads/per_pixel_baseline.py +243 -0
  50. ACT_DP_multitask/detr/models/mask_former/modeling/heads/pixel_decoder.py +294 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ ACT_DP_multitask/detr/models/mr_mg/media/model.gif filter=lfs diff=lfs merge=lfs -text
ACT_DP_multitask/README.md ADDED
@@ -0,0 +1,16 @@
+ ### Install
+ ```
+ cd policy/ACT-DP-TP
+ cd detr
+ pip install -e .
+ cd ..
+ cd Cosmos-Tokenizer
+ pip install -e .
+ #upload policy/ACT-DP-TP/Cosmos-Tokenizer/pretrained_ckpts
+ ```
+ ### Command
+ ```
+ #data_dir: policy/ACT-DP-TP/data_zarr
+ cd policy/ACT-DP-TP
+ bash scripts/act_dp_tp/train.sh bottle_adjust 300 20 20 0
+ ```
ACT_DP_multitask/base.yaml ADDED
@@ -0,0 +1,71 @@
+ common:
+   # The number of historical images
+   img_history_size: 2
+   # The number of future actions to predict
+   action_chunk_size: 64
+   # The number of cameras to be used in the model
+   num_cameras: 3
+   # Dimension for state/action, we use the same space for both state and action
+   # This MUST be equal to configs/state_vec.py
+   state_dim: 128
+
+
+ dataset:
+   # We will extract the data from raw dataset
+   # and store them in the disk buffer by producer
+   # When training, we will read the data
+   # randomly from the buffer by consumer
+   # The producer will replace the data which has been
+   # read by the consumer with new data
+
+   # The path to the buffer (at least 400GB)
+   buf_path: /path/to/buffer
+   # The number of chunks in the buffer
+   buf_num_chunks: 512
+   # The number of samples (step rather than episode) in each chunk
+   buf_chunk_size: 512
+
+   # We will filter the episodes with length less than `epsd_len_thresh_low`
+   epsd_len_thresh_low: 32
+   # For those more than `epsd_len_thresh_high`,
+   # we will randomly sample `epsd_len_thresh_high` steps each time we load the episode
+   # to better balance the training datasets
+   epsd_len_thresh_high: 2048
+   # How to fit the image size
+   image_aspect_ratio: pad
+   # Maximum number of language tokens
+   tokenizer_max_length: 1024
+
+ model:
+   # Config for condition adpators
+   lang_adaptor: mlp2x_gelu
+   img_adaptor: mlp2x_gelu
+   state_adaptor: mlp3x_gelu
+   lang_token_dim: 4096
+   img_token_dim: 1152
+   # Dim of action or proprioception vector
+   # A `state` refers to an action or a proprioception vector
+   state_token_dim: 128
+   # Config for RDT structure
+   rdt:
+     # 1B: num_head 32 hidden_size 2048
+     hidden_size: 2048
+     depth: 28
+     num_heads: 32
+     cond_pos_embed_type: multimodal
+   # For noise scheduler
+   noise_scheduler:
+     type: ddpm
+     num_train_timesteps: 1000
+     num_inference_timesteps: 5
+     beta_schedule: squaredcos_cap_v2 # Critical choice
+     prediction_type: sample
+     clip_sample: False
+   # For EMA (params averaging)
+   # We do not use EMA currently
+   ema:
+     update_after_step: 0
+     inv_gamma: 1.0
+     power: 0.75
+     min_value: 0.0
+     max_value: 0.9999
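
The `noise_scheduler` block above mirrors the constructor arguments of a DDPM scheduler. Below is a minimal sketch (not part of this commit) of reading the config and building such a scheduler with Hugging Face `diffusers`; the scheduler class and file path are assumptions, since the code that consumes `base.yaml` is not shown in this diff.

```python
# Sketch only (not part of this commit). Field names follow the
# model.noise_scheduler block above; DDPMScheduler is an assumed choice
# suggested by `type: ddpm`.
import yaml
from diffusers import DDPMScheduler

with open("ACT_DP_multitask/base.yaml") as f:
    cfg = yaml.safe_load(f)

ns = cfg["model"]["noise_scheduler"]
scheduler = DDPMScheduler(
    num_train_timesteps=ns["num_train_timesteps"],  # 1000
    beta_schedule=ns["beta_schedule"],              # "squaredcos_cap_v2"
    prediction_type=ns["prediction_type"],          # "sample": predict x_0 rather than noise
    clip_sample=ns["clip_sample"],                  # False
)
scheduler.set_timesteps(ns["num_inference_timesteps"])  # 5 denoising steps at inference
print(scheduler.timesteps)
```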
ACT_DP_multitask/detr/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2020 - present, Facebook, Inc
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
ACT_DP_multitask/detr/README.md ADDED
@@ -0,0 +1,9 @@
+ This part of the codebase is modified from DETR https://github.com/facebookresearch/detr under APACHE 2.0.
+
+ @article{Carion2020EndtoEndOD,
+   title={End-to-End Object Detection with Transformers},
+   author={Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko},
+   journal={ArXiv},
+   year={2020},
+   volume={abs/2005.12872}
+ }
ACT_DP_multitask/detr/__pycache__/main.cpython-310.pyc ADDED
Binary file (12.9 kB).
 
ACT_DP_multitask/detr/__pycache__/main.cpython-37.pyc ADDED
Binary file (15.9 kB).
 
ACT_DP_multitask/detr/detr.egg-info/PKG-INFO ADDED
@@ -0,0 +1,17 @@
+ Metadata-Version: 2.2
+ Name: detr
+ Version: 0.0.0
+ License: MIT License
+ License-File: LICENSE
+ Dynamic: description
+ Dynamic: license
+
+ This part of the codebase is modified from DETR https://github.com/facebookresearch/detr under APACHE 2.0.
+
+ @article{Carion2020EndtoEndOD,
+   title={End-to-End Object Detection with Transformers},
+   author={Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko},
+   journal={ArXiv},
+   year={2020},
+   volume={abs/2005.12872}
+ }
ACT_DP_multitask/detr/detr.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,37 @@
+ LICENSE
+ README.md
+ setup.py
+ detr.egg-info/PKG-INFO
+ detr.egg-info/SOURCES.txt
+ detr.egg-info/dependency_links.txt
+ detr.egg-info/top_level.txt
+ models/__init__.py
+ models/backbone.py
+ models/detr_vae.py
+ models/detr_vae_nfp.py
+ models/position_encoding.py
+ models/transformer.py
+ models/vision_transformer.py
+ models/mask_former/__init__.py
+ models/mask_former/config.py
+ models/mask_former/mask_former_model.py
+ models/mask_former/test_time_augmentation.py
+ models/mask_former/modeling/__init__.py
+ models/mask_former/modeling/criterion.py
+ models/mask_former/modeling/matcher.py
+ models/mask_former/modeling/backbone/__init__.py
+ models/mask_former/modeling/backbone/swin.py
+ models/mask_former/modeling/heads/__init__.py
+ models/mask_former/modeling/heads/mask_former_head.py
+ models/mask_former/modeling/heads/per_pixel_baseline.py
+ models/mask_former/modeling/heads/pixel_decoder.py
+ models/mask_former/modeling/transformer/__init__.py
+ models/mask_former/modeling/transformer/position_encoding.py
+ models/mask_former/modeling/transformer/transformer.py
+ models/mask_former/modeling/transformer/transformer_predictor.py
+ models/mask_former/utils/__init__.py
+ models/mask_former/utils/misc.py
+ util/__init__.py
+ util/box_ops.py
+ util/misc.py
+ util/plot_utils.py
ACT_DP_multitask/detr/detr.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
ACT_DP_multitask/detr/detr.egg-info/top_level.txt ADDED
@@ -0,0 +1,2 @@
+ models
+ util
ACT_DP_multitask/detr/main.py ADDED
@@ -0,0 +1,763 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import argparse
3
+ from pathlib import Path
4
+ import os
5
+ import numpy as np
6
+ import torch
7
+ from .models import *
8
+
9
+ import IPython
10
+
11
+ e = IPython.embed
12
+
13
+
14
+ def get_args_parser():
15
+ parser = argparse.ArgumentParser("Set transformer detector", add_help=False)
16
+ parser.add_argument("--ckpt_path", type=str, default='policy/ACT_DP_multitask/checkpoints/real_fintune_50_2000/act_dp')
17
+ parser.add_argument("--eval_ckpts", default=0, type=int, help="eval_ckpts")
18
+ parser.add_argument("--eval_video_log", action="store_true")
19
+ parser.add_argument("--action_interval", default=1, type=int)
20
+ parser.add_argument("--lr", default=1e-4, type=float) # will be overridden
21
+ parser.add_argument("--lr_backbone", default=1e-5, type=float) # will be overridden
22
+ parser.add_argument(
23
+ "--lr_schedule_type", default="constant", type=str, help="lr_schedule_type"
24
+ )
25
+ parser.add_argument(
26
+ "--num_episodes", type=int, help="num_epochs", default=0, required=False
27
+ )
28
+ parser.add_argument("--batch_size", default=2, type=int) # not used
29
+ parser.add_argument(
30
+ "--samples_per_epoch",
31
+ default=1,
32
+ type=int,
33
+ help="samples_per_epoch",
34
+ required=False,
35
+ )
36
+ parser.add_argument("--weight_decay", default=1e-4, type=float)
37
+ parser.add_argument("--epochs", default=300, type=int) # not used
38
+ parser.add_argument("--lr_drop", default=200, type=int) # not used
39
+ parser.add_argument(
40
+ "--clip_max_norm",
41
+ default=0.1,
42
+ type=float, # not used
43
+ help="gradient clipping max norm",
44
+ )
45
+ parser.add_argument("--norm_type", default="meanstd", type=str, help="norm_type")
46
+ parser.add_argument(
47
+ "--num_train_steps", default=50, type=int, help="num_train_steps"
48
+ )
49
+ parser.add_argument(
50
+ "--num_inference_steps", default=10, type=int, help="num_inference_steps"
51
+ )
52
+ parser.add_argument(
53
+ "--schedule_type", default="DDIM", type=str, help="scheduler_type"
54
+ )
55
+ parser.add_argument(
56
+ "--imitate_weight", default=1, type=int, help="imitate Weight", required=False
57
+ )
58
+ parser.add_argument(
59
+ "--prediction_type", default="sample", type=str, help="prediction_type"
60
+ )
61
+ parser.add_argument(
62
+ "--beta_schedule", default="squaredcos_cap_v2", type=str, help="prediction_type"
63
+ )
64
+ parser.add_argument(
65
+ "--diffusion_timestep_type",
66
+ default="cat",
67
+ type=str,
68
+ help="diffusion_timestep_type, cat or add, how to combine timestep",
69
+ )
70
+ parser.add_argument(
71
+ "--condition_type",
72
+ default="cross_attention",
73
+ type=str,
74
+ help="diffusion_condition_type, cross_attention or adaLN, how to combine observation condition",
75
+ )
76
+ parser.add_argument("--attention_type", default="v0", help="decoder attention type")
77
+ parser.add_argument(
78
+ "--causal_mask", action="store_true", help="use causal mask for diffusion"
79
+ )
80
+ parser.add_argument("--loss_type", default="l2", type=str, help="loss_type")
81
+ parser.add_argument(
82
+ "--disable_vae_latent",
83
+ action="store_true",
84
+ help="Use VAE latent space by default",
85
+ )
86
+ parser.add_argument(
87
+ "--disable_resnet",
88
+ action="store_true",
89
+ help="Use resnet to encode obs image by default",
90
+ )
91
+ parser.add_argument(
92
+ "--disable_scale",
93
+ action="store_true",
94
+ help="scale model up",
95
+ )
96
+ parser.add_argument(
97
+ "--inference_num_queries",
98
+ default=0,
99
+ type=int,
100
+ help="inference_num_queries",
101
+ required=False,
102
+ ) # predict_frame
103
+ parser.add_argument(
104
+ "--disable_resize", action="store_true", help="if resize jpeg image"
105
+ )
106
+ parser.add_argument(
107
+ "--share_decoder", action="store_true", help="jpeg and action share decoder"
108
+ )
109
+ parser.add_argument(
110
+ "--resize_rate",
111
+ default=1,
112
+ type=int,
113
+ help="resize rate for pixel prediction",
114
+ required=False,
115
+ )
116
+ parser.add_argument(
117
+ "--image_downsample_rate",
118
+ default=1,
119
+ type=int,
120
+ help="image_downsample_rate",
121
+ required=False,
122
+ )
123
+ parser.add_argument(
124
+ "--temporal_downsample_rate",
125
+ default=1,
126
+ type=int,
127
+ help="temporal_downsample_rate",
128
+ required=False,
129
+ )
130
+ # Model parameters external
131
+ parser.add_argument("--test_num", default=50, type=int, help="test_num")
132
+ parser.add_argument("--save_episode", action="store_true")
133
+ parser.add_argument(
134
+ "--depth_mode",
135
+ default="None",
136
+ type=str,
137
+ help="use depth/depth+coordinate/None. ALL/Single/None",
138
+ )
139
+ parser.add_argument(
140
+ "--pc_mode", default="pc_camera", type=str, help="pc_world/pc_camera"
141
+ )
142
+ parser.add_argument(
143
+ "--disable_multi_view", action="store_true", help="Use multi-view rgb images"
144
+ )
145
+ # * Backbone
146
+ parser.add_argument(
147
+ "--backbone",
148
+ default="resnet18",
149
+ type=str, # will be overridden
150
+ help="Name of the convolutional backbone to use",
151
+ )
152
+ parser.add_argument(
153
+ "--dilation",
154
+ action="store_true",
155
+ help="If true, we replace stride with dilation in the last convolutional block (DC5)",
156
+ )
157
+ parser.add_argument(
158
+ "--position_embedding",
159
+ default="sine",
160
+ type=str,
161
+ choices=("sine", "learned"),
162
+ help="Type of positional embedding to use on top of the image features",
163
+ )
164
+ parser.add_argument(
165
+ "--camera_names",
166
+ default=[],
167
+ type=list, # will be overridden
168
+ help="A list of camera names",
169
+ )
170
+
171
+ # * Transformer
172
+ parser.add_argument(
173
+ "--enc_layers",
174
+ default=4,
175
+ type=int, # will be overridden
176
+ help="Number of encoding layers in the transformer",
177
+ )
178
+ parser.add_argument(
179
+ "--dec_layers",
180
+ default=6,
181
+ type=int, # will be overridden
182
+ help="Number of decoding layers in the transformer",
183
+ )
184
+ parser.add_argument(
185
+ "--dim_feedforward",
186
+ default=2048,
187
+ type=int, # will be overridden
188
+ help="Intermediate size of the feedforward layers in the transformer blocks",
189
+ )
190
+ parser.add_argument(
191
+ "--hidden_dim",
192
+ default=256,
193
+ type=int, # will be overridden
194
+ help="Size of the embeddings (dimension of the transformer)",
195
+ )
196
+ parser.add_argument(
197
+ "--dropout", default=0.1, type=float, help="Dropout applied in the transformer"
198
+ )
199
+ parser.add_argument(
200
+ "--nheads",
201
+ default=8,
202
+ type=int, # will be overridden
203
+ help="Number of attention heads inside the transformer's attentions",
204
+ )
205
+ parser.add_argument(
206
+ "--num_queries",
207
+ default=400,
208
+ type=int, # will be overridden
209
+ help="Number of query slots",
210
+ )
211
+ parser.add_argument("--pre_norm", action="store_true")
212
+
213
+ # # * Segmentation
214
+ parser.add_argument(
215
+ "--masks",
216
+ action="store_true",
217
+ help="Train segmentation head if the flag is provided",
218
+ )
219
+
220
+ # repeat args in imitate_episodes just to avoid error. Will not be used
221
+ parser.add_argument("--eval", action="store_true")
222
+ parser.add_argument("--onscreen_render", action="store_true")
223
+ parser.add_argument(
224
+ "--ckpt_dir", action="store", type=str, help="ckpt_dir", required=False
225
+ )
226
+ parser.add_argument(
227
+ "--policy_class",
228
+ action="store",
229
+ type=str,
230
+ help="policy_class, capitalize",
231
+ required=False,
232
+ )
233
+ parser.add_argument(
234
+ "--task_name", action="store", type=str, help="task_name", required=False
235
+ )
236
+ parser.add_argument("--seed", action="store", type=int, help="seed", required=False)
237
+ parser.add_argument(
238
+ "--num_epochs", action="store", type=int, help="num_epochs", required=False
239
+ )
240
+ parser.add_argument(
241
+ "--kl_weight", action="store", type=int, help="KL Weight", required=False
242
+ )
243
+ parser.add_argument(
244
+ "--save_epoch",
245
+ action="store",
246
+ type=int,
247
+ help="save_epoch",
248
+ default=500,
249
+ required=False,
250
+ )
251
+ parser.add_argument(
252
+ "--chunk_size", action="store", type=int, help="chunk_size", required=False
253
+ )
254
+ parser.add_argument(
255
+ "--history_step", default=0, type=int, help="history_step", required=False
256
+ )
257
+ parser.add_argument(
258
+ "--predict_frame", default=0, type=int, help="predict_frame", required=False
259
+ )
260
+ # add image_width and image_height
261
+ parser.add_argument(
262
+ "--image_width", default=320, type=int, help="image_width", required=False
263
+ )
264
+ parser.add_argument(
265
+ "--image_height", default=240, type=int, help="image_height", required=False
266
+ )
267
+ parser.add_argument(
268
+ "--predict_only_last", action="store_true"
269
+ ) # only predict the last #predict_frame frame
270
+ parser.add_argument("--temporal_agg", action="store_true")
271
+ # visual tokenizer
272
+ parser.add_argument(
273
+ "--tokenizer_model_type",
274
+ default="DV",
275
+ type=str,
276
+ help="tokenizer_model_type, DV,CV,DI,CI",
277
+ )
278
+ parser.add_argument(
279
+ "--tokenizer_model_temporal_rate",
280
+ default=8,
281
+ type=int,
282
+ help="tokenizer_model_temporal_rate, 4,8",
283
+ )
284
+ parser.add_argument(
285
+ "--tokenizer_model_spatial_rate",
286
+ default=16,
287
+ type=int,
288
+ help="tokenizer_model_spatial_rate, 8,16",
289
+ )
290
+ parser.add_argument(
291
+ "--tokenizer_model_name",
292
+ default="Cosmos-Tokenizer-DV4x8x8",
293
+ type=str,
294
+ help="tokenizer_model_name",
295
+ )
296
+ parser.add_argument(
297
+ "--prediction_weight",
298
+ default=1,
299
+ type=float,
300
+ help="pred token Weight",
301
+ required=False,
302
+ )
303
+ parser.add_argument(
304
+ "--token_dim", default=6, type=int, help="token_dim", required=False
305
+ ) # token_pe_type
306
+ parser.add_argument(
307
+ "--patch_size", default=5, type=int, help="patch_size", required=False
308
+ ) # token_pe_type
309
+ parser.add_argument(
310
+ "--token_pe_type",
311
+ default="learned",
312
+ type=str,
313
+ help="token_pe_type",
314
+ required=False,
315
+ )
316
+ parser.add_argument("--nf", action="store_true")
317
+ parser.add_argument("--pretrain", action="store_true", required=False)
318
+ parser.add_argument("--is_wandb", action="store_true")
319
+ parser.add_argument("--mae", action="store_true")
320
+ # parser.add_argument('--seg', action='store_true')
321
+ # parser.add_argument('--seg_next', action='store_true')
322
+
323
+ # parameters for distributed training
324
+ parser.add_argument(
325
+ "--resume",
326
+ default="",
327
+ type=str,
328
+ metavar="PATH",
329
+ help="path to latest checkpoint (default: none)",
330
+ )
331
+ parser.add_argument(
332
+ "--world-size",
333
+ default=-1,
334
+ type=int,
335
+ help="number of nodes for distributed training",
336
+ )
337
+ parser.add_argument(
338
+ "--rank", default=-1, type=int, help="node rank for distributed training"
339
+ )
340
+ parser.add_argument(
341
+ "--dist-url",
342
+ default="tcp://224.66.41.62:23456",
343
+ type=str,
344
+ help="url used to set up distributed training",
345
+ )
346
+ parser.add_argument(
347
+ "--dist-backend", default="nccl", type=str, help="distributed backend"
348
+ )
349
+ # parser.add_argument(
350
+ # "--seed", default=None, type=int, help="seed for initializing training. "
351
+ # )
352
+ parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.")
353
+ parser.add_argument(
354
+ "--multiprocessing-distributed",
355
+ action="store_true",
356
+ help="Use multi-processing distributed training to launch "
357
+ "N processes per node, which has N GPUs. This is the "
358
+ "fastest way to use PyTorch for either single node or "
359
+ "multi node data parallel training",
360
+ )
361
+ parser.add_argument(
362
+ "-j",
363
+ "--workers",
364
+ default=32,
365
+ type=int,
366
+ metavar="N",
367
+ help="number of data loading workers (default: 32)",
368
+ )
369
+
370
+ return parser
371
+
372
+
373
+ def build_ACT_model_and_optimizer(args_override):
374
+ parser = argparse.ArgumentParser(
375
+ "DETR training and evaluation script", parents=[get_args_parser()]
376
+ )
377
+ args = parser.parse_args()
378
+
379
+ for k, v in args_override.items():
380
+ setattr(args, k, v)
381
+
382
+ if args_override["segmentation"]:
383
+ model = build_ACT_Seg_model(args)
384
+ else:
385
+ model = build_ACT_model(args)
386
+ model.cuda()
387
+
388
+ param_dicts = [
389
+ {
390
+ "params": [
391
+ p
392
+ for n, p in model.named_parameters()
393
+ if "backbone" not in n and p.requires_grad
394
+ ]
395
+ },
396
+ {
397
+ "params": [
398
+ p
399
+ for n, p in model.named_parameters()
400
+ if "backbone" in n and p.requires_grad
401
+ ],
402
+ "lr": args.lr_backbone,
403
+ },
404
+ ]
405
+ optimizer = torch.optim.AdamW(
406
+ param_dicts, lr=args.lr, weight_decay=args.weight_decay
407
+ )
408
+
409
+ return model, optimizer
410
+
411
+
412
+ def build_ACTDiffusion_model_and_optimizer(args_override):
413
+ parser = argparse.ArgumentParser(
414
+ "DETR training and evaluation script", parents=[get_args_parser()]
415
+ )
416
+ args = parser.parse_args()
417
+ for k, v in args_override.items():
418
+ setattr(args, k, v)
419
+ # print('args',args) # get
420
+ model = build_ACTDiffusion_model(args)
421
+ model.cuda()
422
+
423
+ param_dicts = [
424
+ {
425
+ "params": [
426
+ p
427
+ for n, p in model.named_parameters()
428
+ if "backbone" not in n and p.requires_grad
429
+ ]
430
+ },
431
+ {
432
+ "params": [
433
+ p
434
+ for n, p in model.named_parameters()
435
+ if "backbone" in n and p.requires_grad
436
+ ],
437
+ "lr": args.lr_backbone,
438
+ },
439
+ ]
440
+ optimizer = torch.optim.AdamW(
441
+ param_dicts, lr=args.lr, weight_decay=args.weight_decay
442
+ )
443
+
444
+ return model, optimizer
445
+
446
+
447
+ def build_ACTDiffusion_tactile_model_and_optimizer(args_override):
448
+ parser = argparse.ArgumentParser(
449
+ "DETR training and evaluation script", parents=[get_args_parser()]
450
+ )
451
+ args = parser.parse_args()
452
+ for k, v in args_override.items():
453
+ setattr(args, k, v)
454
+ # print('args',args) # get
455
+ model = build_ACTDiffusion_tactile_model(args)
456
+ model.cuda()
457
+
458
+ param_dicts = [
459
+ {
460
+ "params": [
461
+ p
462
+ for n, p in model.named_parameters()
463
+ if "backbone" not in n and p.requires_grad
464
+ ]
465
+ },
466
+ {
467
+ "params": [
468
+ p
469
+ for n, p in model.named_parameters()
470
+ if "backbone" in n and p.requires_grad
471
+ ],
472
+ "lr": args.lr_backbone,
473
+ },
474
+ ]
475
+ optimizer = torch.optim.AdamW(
476
+ param_dicts, lr=args.lr, weight_decay=args.weight_decay
477
+ )
478
+
479
+ return model, optimizer
480
+
481
+
482
+ def build_diffusion_tp_model_and_optimizer(args_override):
483
+ parser = argparse.ArgumentParser(
484
+ "DETR training and evaluation script", parents=[get_args_parser()]
485
+ )
486
+ args = parser.parse_args()
487
+ for k, v in args_override.items():
488
+ setattr(args, k, v)
489
+ # print('args',args) # get
490
+ model = build_ACTDiffusion_tp_model(args)
491
+ model.cuda()
492
+
493
+ return model # , optimizer
494
+
495
+
496
+ def build_diffusion_pp_model_and_optimizer(args_override):
497
+ parser = argparse.ArgumentParser(
498
+ "DETR training and evaluation script", parents=[get_args_parser()]
499
+ )
500
+ args = parser.parse_args()
501
+ for k, v in args_override.items():
502
+ setattr(args, k, v)
503
+ # print('args',args) # get
504
+ model = build_ACTDiffusion_pp_model(args)
505
+ model.cuda()
506
+
507
+ return model
508
+
509
+
510
+ # discard
511
+
512
+
513
+ def build_ACT_NF_model_and_optimizer(args_override):
514
+ parser = argparse.ArgumentParser(
515
+ "DETR training and evaluation script", parents=[get_args_parser()]
516
+ )
517
+ args = parser.parse_args()
518
+
519
+ for k, v in args_override.items():
520
+ setattr(args, k, v)
521
+
522
+ model = build_ACT_NF_model(args)
523
+ model.cuda()
524
+
525
+ param_dicts = [
526
+ {
527
+ "params": [
528
+ p
529
+ for n, p in model.named_parameters()
530
+ if "backbone" not in n and p.requires_grad
531
+ ]
532
+ },
533
+ {
534
+ "params": [
535
+ p
536
+ for n, p in model.named_parameters()
537
+ if "backbone" in n and p.requires_grad
538
+ ],
539
+ "lr": args.lr_backbone,
540
+ },
541
+ ]
542
+ optimizer = torch.optim.AdamW(
543
+ param_dicts, lr=args.lr, weight_decay=args.weight_decay
544
+ )
545
+
546
+ return model, optimizer
547
+
548
+
549
+ def build_ACT_Dino_model_and_optimizer(args_override):
550
+ parser = argparse.ArgumentParser(
551
+ "DETR training and evaluation script", parents=[get_args_parser()]
552
+ )
553
+ args = parser.parse_args()
554
+
555
+ for k, v in args_override.items():
556
+ setattr(args, k, v)
557
+
558
+ model = build_ACT_dino_model(args)
559
+ model.cuda()
560
+
561
+ param_dicts = [
562
+ {
563
+ "params": [
564
+ p
565
+ for n, p in model.named_parameters()
566
+ if "backbone" not in n and p.requires_grad
567
+ ]
568
+ },
569
+ {
570
+ "params": [
571
+ p
572
+ for n, p in model.named_parameters()
573
+ if "backbone" in n and p.requires_grad
574
+ ],
575
+ "lr": args.lr_backbone,
576
+ },
577
+ ]
578
+ optimizer = torch.optim.AdamW(
579
+ param_dicts, lr=args.lr, weight_decay=args.weight_decay
580
+ )
581
+
582
+ return model, optimizer
583
+
584
+
585
+ def build_ACT_jpeg_model_and_optimizer(args_override):
586
+ parser = argparse.ArgumentParser(
587
+ "DETR training and evaluation script", parents=[get_args_parser()]
588
+ )
589
+ args = parser.parse_args()
590
+
591
+ for k, v in args_override.items():
592
+ setattr(args, k, v)
593
+
594
+ model = build_ACT_jpeg_model(args)
595
+ model.cuda()
596
+
597
+ param_dicts = [
598
+ {
599
+ "params": [
600
+ p
601
+ for n, p in model.named_parameters()
602
+ if "backbone" not in n and p.requires_grad
603
+ ]
604
+ },
605
+ {
606
+ "params": [
607
+ p
608
+ for n, p in model.named_parameters()
609
+ if "backbone" in n and p.requires_grad
610
+ ],
611
+ "lr": args.lr_backbone,
612
+ },
613
+ ]
614
+ optimizer = torch.optim.AdamW(
615
+ param_dicts, lr=args.lr, weight_decay=args.weight_decay
616
+ )
617
+
618
+ return model, optimizer
619
+
620
+
621
+ def build_ACT_jpeg_diffusion_model_and_optimizer(args_override):
622
+ parser = argparse.ArgumentParser(
623
+ "DETR training and evaluation script", parents=[get_args_parser()]
624
+ )
625
+ args = parser.parse_args()
626
+
627
+ for k, v in args_override.items():
628
+ setattr(args, k, v)
629
+
630
+ model = build_ACT_jpeg_diffusion_model(args)
631
+ model.cuda()
632
+
633
+ param_dicts = [
634
+ {
635
+ "params": [
636
+ p
637
+ for n, p in model.named_parameters()
638
+ if "backbone" not in n and p.requires_grad
639
+ ]
640
+ },
641
+ {
642
+ "params": [
643
+ p
644
+ for n, p in model.named_parameters()
645
+ if "backbone" in n and p.requires_grad
646
+ ],
647
+ "lr": args.lr_backbone,
648
+ },
649
+ ]
650
+ optimizer = torch.optim.AdamW(
651
+ param_dicts, lr=args.lr, weight_decay=args.weight_decay
652
+ )
653
+
654
+ return model, optimizer
655
+
656
+
657
+ def build_ACT_jpeg_diffusion_seperate_model_and_optimizer(args_override):
658
+ parser = argparse.ArgumentParser(
659
+ "DETR training and evaluation script", parents=[get_args_parser()]
660
+ )
661
+ args = parser.parse_args()
662
+
663
+ for k, v in args_override.items():
664
+ setattr(args, k, v)
665
+
666
+ model = build_ACT_jpeg_diffusion_seperate_model(args)
667
+ model.cuda()
668
+
669
+ param_dicts = [
670
+ {
671
+ "params": [
672
+ p
673
+ for n, p in model.named_parameters()
674
+ if "backbone" not in n and p.requires_grad
675
+ ]
676
+ },
677
+ {
678
+ "params": [
679
+ p
680
+ for n, p in model.named_parameters()
681
+ if "backbone" in n and p.requires_grad
682
+ ],
683
+ "lr": args.lr_backbone,
684
+ },
685
+ ]
686
+ optimizer = torch.optim.AdamW(
687
+ param_dicts, lr=args.lr, weight_decay=args.weight_decay
688
+ )
689
+
690
+ return model, optimizer
691
+
692
+
693
+ def build_nf_diffusion_seperate_model_and_optimizer(args_override):
694
+ parser = argparse.ArgumentParser(
695
+ "DETR training and evaluation script", parents=[get_args_parser()]
696
+ )
697
+ args = parser.parse_args()
698
+
699
+ for k, v in args_override.items():
700
+ setattr(args, k, v)
701
+
702
+ model = build_nf_diffusion_seperate_model(args)
703
+ model.cuda()
704
+
705
+ param_dicts = [
706
+ {
707
+ "params": [
708
+ p
709
+ for n, p in model.named_parameters()
710
+ if "backbone" not in n and p.requires_grad
711
+ ]
712
+ },
713
+ {
714
+ "params": [
715
+ p
716
+ for n, p in model.named_parameters()
717
+ if "backbone" in n and p.requires_grad
718
+ ],
719
+ "lr": args.lr_backbone,
720
+ },
721
+ ]
722
+ optimizer = torch.optim.AdamW(
723
+ param_dicts, lr=args.lr, weight_decay=args.weight_decay
724
+ )
725
+
726
+ return model, optimizer
727
+
728
+
729
+ def build_CNNMLP_model_and_optimizer(args_override):
730
+ parser = argparse.ArgumentParser(
731
+ "DETR training and evaluation script", parents=[get_args_parser()]
732
+ )
733
+ args = parser.parse_args()
734
+
735
+ for k, v in args_override.items():
736
+ setattr(args, k, v)
737
+
738
+ model = build_CNNMLP_model(args)
739
+ model.cuda()
740
+
741
+ param_dicts = [
742
+ {
743
+ "params": [
744
+ p
745
+ for n, p in model.named_parameters()
746
+ if "backbone" not in n and p.requires_grad
747
+ ]
748
+ },
749
+ {
750
+ "params": [
751
+ p
752
+ for n, p in model.named_parameters()
753
+ if "backbone" in n and p.requires_grad
754
+ ],
755
+ "lr": args.lr_backbone,
756
+ },
757
+ ]
758
+ optimizer = torch.optim.AdamW(
759
+ param_dicts, lr=args.lr, weight_decay=args.weight_decay
760
+ )
761
+
762
+ return model, optimizer
763
+
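
All of the `build_*_model_and_optimizer` helpers in `main.py` follow the same pattern: instantiate the DETR argument parser, overwrite its defaults from an `args_override` dict via `setattr`, build the requested model, move it to GPU, and create an AdamW optimizer with a separate (lower) learning rate for the backbone parameters. A minimal sketch of calling one of them follows; the import path and every override key except `segmentation` (which the builder reads explicitly) are hypothetical examples, not part of this commit.

```python
# Sketch only: assumed import path and illustrative override keys.
from detr.main import build_ACT_model_and_optimizer

args_override = {
    "segmentation": False,    # read by the builder: picks build_ACT_model vs build_ACT_Seg_model
    "lr": 1e-5,               # overwrites the parser default of 1e-4
    "lr_backbone": 1e-5,      # learning rate of the second (backbone) AdamW param group
    "num_queries": 100,       # hypothetical policy-specific field
    "camera_names": ["top"],  # hypothetical camera list
}

# Note: the builder calls parser.parse_args(), so it also consumes sys.argv,
# and model.cuda() inside it requires a CUDA device.
model, optimizer = build_ACT_model_and_optimizer(args_override)
print(sum(p.numel() for p in model.parameters() if p.requires_grad), "trainable params")
```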
ACT_DP_multitask/detr/models/__init__.py ADDED
@@ -0,0 +1,60 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ from .detr_vae import build as build_vae
+ from .detr_vae import build_seg as build_vae_seg
+ from .detr_vae_nfp import build as build_vae_nfp
+ from .detr_vae import build_cnnmlp as build_cnnmlp
+ from .detr_vae import build_dino as build_dino
+ from .detr_vae import build_jpeg as build_jpeg
+ from .detr_vae import build_jpeg_diffusion as build_jpeg_diffusion
+ from .detr_vae import build_jpeg_diffusion_seperate as build_jpeg_diffusion_seperate
+ from .detr_vae import build_nf_diffusion_seperate as build_nf_diffusion_seperate
+ from .detr_vae import build_diffusion as build_diffusion
+ from .detr_vae import build_diffusion_tp as build_diffusion_tp
+ from .detr_vae import build_diffusion_tp_with_dual_visual_token as build_diffusion_tp_with_dual_visual_token
+ from .detr_vae import build_diffusion_pp as build_diffusion_pp
+ from .detr_vae import build_diffusion_tactile as build_diffusion_tactile
+
+ def build_ACT_model(args):
+     return build_vae(args)
+
+ def build_CNNMLP_model(args):
+     return build_cnnmlp(args)
+
+ def build_ACTDiffusion_model(args):
+     return build_diffusion(args)
+
+ def build_ACTDiffusion_tactile_model(args):
+     return build_diffusion_tactile(args)
+
+ def build_ACTDiffusion_tp_model(args):
+     if args.diffusion_timestep_type == 'vis_cat': # HARDCODE whether use tokenizer feature for decoder & action prediction
+         print('Using dual visual token for decoder and action prediction')
+         return build_diffusion_tp_with_dual_visual_token(args)
+     else:
+         return build_diffusion_tp(args)
+
+ def build_ACTDiffusion_pp_model(args):
+     return build_diffusion_pp(args)
+
+ # discard
+ def build_ACT_NF_model(args):
+     return build_vae_nfp(args)
+
+ def build_ACT_Seg_model(args):
+     return build_vae_seg(args)
+
+ def build_ACT_dino_model(args):
+     return build_dino(args)
+
+ def build_ACT_jpeg_model(args):
+     return build_jpeg(args)
+
+ def build_ACT_jpeg_diffusion_model(args):
+     return build_jpeg_diffusion(args)
+
+ def build_ACT_jpeg_diffusion_seperate_model(args):
+     return build_jpeg_diffusion_seperate(args)
+
+ def build_nf_diffusion_seperate_model(args):
+     return build_nf_diffusion_seperate(args)
+
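
`build_ACTDiffusion_tp_model` above is the only builder with a branch: it routes to the dual-visual-token variant purely on `args.diffusion_timestep_type == 'vis_cat'`. A small sketch (not part of this commit) that only inspects that choice, without instantiating a model, since the real builders need a fully populated argument namespace:

```python
# Sketch only: mirrors the branch in build_ACTDiffusion_tp_model.
from argparse import Namespace

def chosen_tp_builder_name(args) -> str:
    if args.diffusion_timestep_type == "vis_cat":
        return "build_diffusion_tp_with_dual_visual_token"
    return "build_diffusion_tp"

print(chosen_tp_builder_name(Namespace(diffusion_timestep_type="vis_cat")))  # dual visual tokens
print(chosen_tp_builder_name(Namespace(diffusion_timestep_type="cat")))      # default tp builder
```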
ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.47 kB).
 
ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (2.55 kB).
 
ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (2.2 kB).
 
ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-310.pyc ADDED
Binary file (6.66 kB).
 
ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-37.pyc ADDED
Binary file (4.32 kB).
 
ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-38.pyc ADDED
Binary file (4.35 kB).
 
ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-310.pyc ADDED
Binary file (50.1 kB).
 
ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-37.pyc ADDED
Binary file (57.9 kB).
 
ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-38.pyc ADDED
Binary file (40.8 kB).
 
ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-310.pyc ADDED
Binary file (13.3 kB).
 
ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-37.pyc ADDED
Binary file (15.1 kB).
 
ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-38.pyc ADDED
Binary file (13.4 kB).
 
ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-310.pyc ADDED
Binary file (3.61 kB).
 
ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-37.pyc ADDED
Binary file (3.55 kB).
 
ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-38.pyc ADDED
Binary file (3.56 kB).
 
ACT_DP_multitask/detr/models/__pycache__/resnet_film.cpython-310.pyc ADDED
Binary file (13.5 kB).
 
ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (39 kB).
 
ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-37.pyc ADDED
Binary file (40.5 kB).
 
ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-38.pyc ADDED
Binary file (24.5 kB).
 
ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-310.pyc ADDED
Binary file (13.1 kB).
 
ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-37.pyc ADDED
Binary file (13.5 kB).
 
ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-38.pyc ADDED
Binary file (13.2 kB).
 
ACT_DP_multitask/detr/models/backbone.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Backbone modules.
4
+ """
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision
10
+ from torch import nn
11
+ from torchvision.models._utils import IntermediateLayerGetter
12
+ from typing import Dict, List
13
+ from typing import Any, Dict, List, Mapping, Optional
14
+ from ..util.misc import NestedTensor, is_main_process
15
+
16
+ from .position_encoding import build_position_encoding
17
+ from .resnet_film import resnet18 as resnet18_film
18
+ from .resnet_film import resnet34 as resnet34_film
19
+ import IPython
20
+ e = IPython.embed
21
+
22
+ class FrozenBatchNorm2d(torch.nn.Module):
23
+ """
24
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
25
+
26
+ Copy-paste from torchvision.misc.ops with added eps before rqsrt,
27
+ without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
28
+ produce nans.
29
+ """
30
+
31
+ def __init__(self, n):
32
+ super(FrozenBatchNorm2d, self).__init__()
33
+ self.register_buffer("weight", torch.ones(n))
34
+ self.register_buffer("bias", torch.zeros(n))
35
+ self.register_buffer("running_mean", torch.zeros(n))
36
+ self.register_buffer("running_var", torch.ones(n))
37
+
38
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
39
+ missing_keys, unexpected_keys, error_msgs):
40
+ num_batches_tracked_key = prefix + 'num_batches_tracked'
41
+ if num_batches_tracked_key in state_dict:
42
+ del state_dict[num_batches_tracked_key]
43
+
44
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
45
+ state_dict, prefix, local_metadata, strict,
46
+ missing_keys, unexpected_keys, error_msgs)
47
+
48
+ def forward(self, x):
49
+ # move reshapes to the beginning
50
+ # to make it fuser-friendly
51
+ w = self.weight.reshape(1, -1, 1, 1)
52
+ b = self.bias.reshape(1, -1, 1, 1)
53
+ rv = self.running_var.reshape(1, -1, 1, 1)
54
+ rm = self.running_mean.reshape(1, -1, 1, 1)
55
+ eps = 1e-5
56
+ scale = w * (rv + eps).rsqrt()
57
+ bias = b - rm * scale
58
+ return x * scale + bias
59
+
60
+
61
+ class BackboneBase(nn.Module):
62
+
63
+ def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
64
+ super().__init__()
65
+ # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
66
+ # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
67
+ # parameter.requires_grad_(False)
68
+ if return_interm_layers:
69
+ return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
70
+ else:
71
+ return_layers = {'layer4': "0"}
72
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
73
+ self.num_channels = num_channels
74
+
75
+ def forward(self, tensor):
76
+ xs = self.body(tensor)
77
+ return xs
78
+ # out: Dict[str, NestedTensor] = {}
79
+ # for name, x in xs.items():
80
+ # m = tensor_list.mask
81
+ # assert m is not None
82
+ # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
83
+ # out[name] = NestedTensor(x, mask)
84
+ # return out
85
+
86
+
87
+ class Backbone(BackboneBase):
88
+ """ResNet backbone with frozen BatchNorm."""
89
+ def __init__(self, name: str,
90
+ train_backbone: bool,
91
+ return_interm_layers: bool,
92
+ dilation: bool):
93
+ backbone = getattr(torchvision.models, name)(
94
+ replace_stride_with_dilation=[False, False, dilation],
95
+ pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
96
+ num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
97
+ super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
98
+
99
+ # ==== ResNet Backbone ====
100
+ class ResNetFilmBackbone(nn.Module):
101
+ def __init__(self, embedding_name: str, pretrained: bool = False,
102
+ film_config: Optional[Mapping[str, Any]] = None):
103
+ super().__init__()
104
+ self._pretrained = pretrained
105
+ weights = 'IMAGENET1K_V1' if pretrained else None
106
+ if embedding_name in ('resnet34_film', 'resnet34'):
107
+ backbone = resnet34_film(weights=weights, film_config=film_config, pretrained=pretrained)
108
+ embedding_dim = 512
109
+ elif embedding_name in ('resnet18_film', 'resnet18'):
110
+ backbone = resnet18_film(weights=weights, film_config=film_config, pretrained=pretrained)
111
+ embedding_dim = 512
112
+ else:
113
+ raise NotImplementedError
114
+
115
+ self.resnet_film_model = backbone
116
+ self._embedding_dim = embedding_dim
117
+ self.resnet_film_model.fc = nn.Identity()
118
+ self.resnet_film_model.avgpool = nn.Identity()
119
+
120
+ self.num_channels = self._embedding_dim
121
+
122
+ # FiLM config
123
+ self.film_config = film_config
124
+ if film_config is not None and film_config['use']:
125
+ film_models = []
126
+ for layer_idx, num_blocks in enumerate(self.resnet_film_model.layers):
127
+ if layer_idx in film_config['use_in_layers']:
128
+ num_planes = self.resnet_film_model.film_planes[layer_idx]
129
+ film_model_layer = nn.Linear(
130
+ film_config['task_embedding_dim'], num_blocks * 2 * num_planes)
131
+ else:
132
+ film_model_layer = None
133
+ film_models.append(film_model_layer)
134
+
135
+ self.film_models = nn.ModuleList(film_models)
136
+
137
+ def forward(self, x, texts: Optional[List[str]] = None, task_emb: Optional[torch.Tensor] = None, **kwargs):
138
+ film_outputs = None
139
+ if self.film_config is not None and self.film_config['use']:
140
+ film_outputs = []
141
+ for layer_idx, num_blocks in enumerate(self.resnet_film_model.layers):
142
+ if self.film_config['use'] and self.film_models[layer_idx] is not None:
143
+ film_features = self.film_models[layer_idx](task_emb)
144
+ else:
145
+ film_features = None
146
+ film_outputs.append(film_features)
147
+ return self.resnet_film_model(x, film_features=film_outputs, flatten=False)
148
+
149
+ @property
150
+ def embed_dim(self):
151
+ return self._embedding_dim
152
+
153
+
154
+ # class Joiner(nn.Sequential):
155
+ # def __init__(self, backbone, position_embedding):
156
+ # super().__init__(backbone, position_embedding)
157
+
158
+ # def forward(self, tensor_list: NestedTensor, task_emb:NestedTensor):
159
+ # xs = self[0](tensor_list)
160
+ # out: List[NestedTensor] = []
161
+ # pos = []
162
+ # for name, x in xs.items():
163
+ # out.append(x)
164
+ # # position encoding
165
+ # pos.append(self[1](x).to(x.dtype))
166
+
167
+ # return out, pos
168
+
169
+ class Joiner(nn.Sequential):
170
+ def __init__(self, backbone, position_embedding):
171
+ super().__init__(backbone, position_embedding)
172
+
173
+ def forward(self, tensor_list: NestedTensor, task_emb: Optional[Any] = None):
174
+ if task_emb is not None:
175
+ xs = self[0](tensor_list, task_emb=task_emb)
176
+ # Make a dictionary out of the last layer outputs since we don't have IntermediateLayerGetter
177
+ xs = {'0': xs}
178
+ else:
179
+ xs = self[0](tensor_list)
180
+ out: List[NestedTensor] = []
181
+ pos = []
182
+ for name, x in xs.items():
183
+ out.append(x)
184
+ # position encoding
185
+ pos.append(self[1](x).to(x.dtype))
186
+
187
+ return out, pos
188
+
189
+ def build_backbone(args):
190
+ position_embedding = build_position_encoding(args)
191
+ train_backbone = args.lr_backbone > 0
192
+ return_interm_layers = args.masks
193
+ backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
194
+ model = Joiner(backbone, position_embedding)
195
+ model.num_channels = backbone.num_channels
196
+ return model
197
+
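A minimal usage sketch of `build_backbone` (the `Namespace` fields below mirror the attributes read by `build_backbone` and a DETR-style `build_position_encoding`; they are assumptions about the surrounding argument parser, not definitions from this file):

```python
from argparse import Namespace
import torch

# Hypothetical argument values; adjust to the repo's actual parser.
args = Namespace(backbone='resnet18', lr_backbone=1e-5, masks=False,
                 dilation=False, hidden_dim=256, position_embedding='sine')
model = build_backbone(args)
features, pos = model(torch.randn(2, 3, 224, 224))  # Joiner returns ([features], [pos])
# e.g. features[0] is roughly (2, 512, 7, 7) for ResNet-18 layer4,
# and pos[0] carries the matching positional encoding.
```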
198
+ def build_film_backbone(args):
199
+ position_embedding = build_position_encoding(args)
200
+ film_config = {
201
+ 'use': True,
202
+ 'use_in_layers': [1, 2, 3],
203
+ 'task_embedding_dim': 512,
204
+ 'film_planes': [64, 128, 256, 512],
205
+ }
206
+ backbone = ResNetFilmBackbone(args.backbone, film_config=film_config)
207
+ model = Joiner(backbone, position_embedding)
208
+ model.num_channels = backbone.num_channels
209
+ return model
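For context, FiLM conditioning predicts a per-channel scale and shift from the task embedding; each `nn.Linear` in `ResNetFilmBackbone` emits `num_blocks * 2 * num_planes` values for its layer. A standalone illustration of the usual `(1 + gamma) * x + beta` form (the exact variant applied inside `resnet18_film`/`resnet34_film` may differ):

```python
import torch

task_emb = torch.randn(4, 512)                      # (batch, task_embedding_dim)
num_blocks, num_planes = 2, 128                     # e.g. layer2 of a ResNet-18
film_layer = torch.nn.Linear(512, num_blocks * 2 * num_planes)
gb = film_layer(task_emb).view(4, num_blocks, 2, num_planes)
gamma, beta = gb[:, 0, 0], gb[:, 0, 1]              # parameters for the first block
feat = torch.randn(4, num_planes, 28, 28)
modulated = (1 + gamma)[..., None, None] * feat + beta[..., None, None]
```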
ACT_DP_multitask/detr/models/detr_vae.py ADDED
The diff for this file is too large to render. See raw diff
 
ACT_DP_multitask/detr/models/detr_vae_nfp.py ADDED
@@ -0,0 +1,523 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ DETR model and criterion classes.
4
+ """
5
+ import torch
6
+ from torch import nn
7
+ from torch.autograd import Variable
8
+ from .backbone import build_backbone
9
+ from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer
10
+ from .vision_transformer import Block, get_2d_sincos_pos_embed, get_2d_sincos_pos_embed_v2
11
+ from .mr_mg.policy.model.vision_transformer import vit_base
12
+
13
+ import numpy as np
14
+
15
+ import IPython
16
+ e = IPython.embed
17
+
18
+
19
+ def reparametrize(mu, logvar):
20
+ std = logvar.div(2).exp()
21
+ eps = Variable(std.data.new(std.size()).normal_())
22
+ return mu + std * eps
23
+
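`reparametrize` is the standard VAE reparameterization trick, z = mu + sigma * eps with eps ~ N(0, I), so gradients flow through mu and logvar rather than through the sampling step. `torch.autograd.Variable` is deprecated; an equivalent sketch on current PyTorch:

```python
def reparametrize_modern(mu, logvar):
    std = (0.5 * logvar).exp()      # sigma = exp(logvar / 2)
    eps = torch.randn_like(std)     # eps ~ N(0, I); no Variable wrapper needed
    return mu + std * eps
```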
24
+
25
+ def get_sinusoid_encoding_table(n_position, d_hid):
26
+ def get_position_angle_vec(position):
27
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
28
+
29
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
30
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
31
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
32
+
33
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
34
+
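`get_sinusoid_encoding_table` builds the fixed sinusoidal positional encoding; here `n_position = 1 + 1 + num_queries` covers the [CLS] token, the qpos token, and the action sequence:

```latex
PE(p,\,2i)   = \sin\!\left(\frac{p}{10000^{\,2i/d_{\text{hid}}}}\right),
\qquad
PE(p,\,2i+1) = \cos\!\left(\frac{p}{10000^{\,2i/d_{\text{hid}}}}\right)
```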
35
+
36
+ class DETRVAE(nn.Module):
37
+ """ DETR-style CVAE policy (ACT) with an auxiliary next-frame prediction head """
38
+ def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
39
+ """ Initializes the model.
40
+ Parameters:
41
+ backbones: torch module of the backbone to be used. See backbone.py
42
+ transformer: torch module of the transformer architecture. See transformer.py
43
+ state_dim: robot state dimension of the environment
44
+ num_queries: number of action queries, i.e. the length of the predicted action chunk.
45
+ (In the original DETR this was the maximal number of detected objects, e.g. 100 for COCO.)
46
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
47
+ """
48
+ super().__init__()
49
+ self.num_queries = num_queries
50
+ self.camera_names = camera_names
51
+ self.transformer = transformer
52
+ self.encoder = encoder
53
+ hidden_dim = transformer.d_model
54
+ self.action_head = nn.Linear(hidden_dim, state_dim)
55
+ self.is_pad_head = nn.Linear(hidden_dim, 1)
56
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
57
+ if backbones is not None:
58
+ self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
59
+ self.backbones = nn.ModuleList(backbones)
60
+ self.input_proj_robot_state = nn.Linear(14, hidden_dim)
61
+ else:
62
+ # input_dim = 14 + 7 # robot_state + env_state
63
+ self.input_proj_robot_state = nn.Linear(14, hidden_dim)
64
+ self.input_proj_env_state = nn.Linear(7, hidden_dim)
65
+ self.pos = torch.nn.Embedding(2, hidden_dim)
66
+ self.backbones = None
67
+
68
+ # encoder extra parameters
69
+ self.latent_dim = 32 # final size of latent z # TODO tune
70
+ self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
71
+ self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
72
+ self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
73
+ self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
74
+ self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq
75
+
76
+ # decoder extra parameters
77
+ self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
78
+ self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
79
+
80
+ # settings for next frame prediction
81
+ self.patch_size = 16
82
+ # self.image_size = 224
83
+ # self.img_h, self.img_w = 128, 160
84
+ self.img_h, self.img_w = 224, 224
85
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
86
+ # self.n_patch = (self.image_size//self.patch_size)**2
87
+ self.k = 1 # number of next frames
88
+ self.n_patch = (self.img_h//self.patch_size)*(self.img_w//self.patch_size)*(self.k)
89
+ self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.n_patch, hidden_dim), requires_grad=False) # (1, n_patch, h)
90
+ self.patch_embed = nn.Embedding(self.n_patch, hidden_dim)
91
+ self.decoder_embed = nn.Linear(hidden_dim, hidden_dim, bias=True)
92
+
93
+ decoder_depth = 2 # hardcode
94
+ self.decoder_blocks = nn.ModuleList([
95
+ Block(hidden_dim, 16, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm)
96
+ for i in range(decoder_depth)])
97
+
98
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
99
+ self.decoder_pred = nn.Linear(hidden_dim, self.patch_size**2 * 3, bias=True) # decoder to patch
100
+
101
+ # decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], (self.image_size//self.patch_size), cls_token=False)
102
+ decoder_pos_embed = get_2d_sincos_pos_embed_v2(self.decoder_pos_embed.shape[-1], (self.img_h//self.patch_size, self.img_w//self.patch_size))
103
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0).repeat(1,self.k,1))
104
+
105
+ # fwd_params = sum(p.numel() for p in self.decoder_blocks.parameters() if p.requires_grad)
106
+
107
+
108
+ def forward(self, qpos, image, env_state, actions=None, is_pad=None):
109
+ """
110
+ qpos: batch, qpos_dim
111
+ image: batch, num_cam, channel, height, width
112
+ env_state: None
113
+ actions: batch, seq, action_dim
114
+ """
115
+ is_training = actions is not None # train or val
116
+ bs, _ = qpos.shape
117
+ ### Obtain latent z from action sequence
118
+ if is_training:
119
+ # project action sequence to embedding dim, and concat with a CLS token
120
+ action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
121
+ qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
122
+ qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
123
+ cls_embed = self.cls_embed.weight # (1, hidden_dim)
124
+ cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
125
+ encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
126
+ encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
127
+ # do not mask cls token
128
+ cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
129
+ is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
130
+ # obtain position embedding
131
+ pos_embed = self.pos_table.clone().detach()
132
+ pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
133
+ # query model
134
+ encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
135
+ encoder_output = encoder_output[0] # take cls output only
136
+ latent_info = self.latent_proj(encoder_output)
137
+ mu = latent_info[:, :self.latent_dim]
138
+ logvar = latent_info[:, self.latent_dim:]
139
+ latent_sample = reparametrize(mu, logvar)
140
+ latent_input = self.latent_out_proj(latent_sample)
141
+ else:
142
+ mu = logvar = None
143
+ latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
144
+ latent_input = self.latent_out_proj(latent_sample)
145
+
146
+ if self.backbones is not None:
147
+ # Image observation features and position embeddings
148
+ all_cam_features = []
149
+ all_cam_pos = []
150
+ if is_training:
151
+ next_frame_images = image[:,1:]
152
+ image = image[:,:1]
153
+ for cam_id, cam_name in enumerate(self.camera_names):
154
+ features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED?
155
+ features = features[0] # take the last layer feature
156
+ pos = pos[0]
157
+ all_cam_features.append(self.input_proj(features))
158
+ all_cam_pos.append(pos)
159
+ # proprioception features
160
+ proprio_input = self.input_proj_robot_state(qpos)
161
+ # fold camera dimension into width dimension
162
+ src = torch.cat(all_cam_features, axis=3)
163
+ pos = torch.cat(all_cam_pos, axis=3)
164
+ query_embed = torch.cat([self.query_embed.weight, self.patch_embed.weight], axis=0)
165
+ hs = self.transformer(src, None, query_embed, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
166
+ # hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
167
+ else:
168
+ qpos = self.input_proj_robot_state(qpos)
169
+ env_state = self.input_proj_env_state(env_state)
170
+ transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
171
+ hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
172
+ a_hat = self.action_head(hs[:,:self.num_queries])
173
+ is_pad_hat = self.is_pad_head(hs[:,:self.num_queries])
174
+
175
+ # next frame prediction
176
+ mask_token = self.mask_token
177
+ mask_tokens = mask_token.repeat(bs, self.n_patch, 1)
178
+ mask_tokens = mask_tokens + self.decoder_pos_embed.repeat(bs, 1, 1)
179
+
180
+ obs_pred = self.decoder_embed(hs[:,self.num_queries:])
181
+ obs_pred_ = torch.cat([obs_pred, mask_tokens], dim=1)
182
+ for blk in self.decoder_blocks:
183
+ obs_pred_ = blk(obs_pred_)
184
+ obs_pred_ = self.decoder_norm(obs_pred_)
185
+ obs_preds = self.decoder_pred(obs_pred_)
186
+ obs_preds = obs_preds[:,self.n_patch:]
187
+
188
+ if is_training:
189
+ # next_frame_images = image[:,1:]
190
+ next_frame_images = nn.functional.interpolate(next_frame_images.reshape(bs, self.k*3, 224, 224), size=(self.img_h, self.img_w))
191
+ p = self.patch_size
192
+ h_p = self.img_h // p
193
+ w_p = self.img_w // p
194
+ obs_targets = next_frame_images.reshape(shape=(bs, self.k, 3, h_p, p, w_p, p))
195
+ obs_targets = obs_targets.permute(0,1,3,5,4,6,2)
196
+ obs_targets = obs_targets.reshape(shape=(bs, h_p*w_p*self.k, (p**2)*3))
197
+ else:
198
+ obs_targets = torch.zeros_like(obs_preds)
199
+
200
+ return a_hat, is_pad_hat, [mu, logvar], [obs_preds, obs_targets]
201
+
202
+
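The next-frame targets above are built by "patchifying" the ground-truth image: each non-overlapping p x p patch is flattened to a vector of length p*p*3, matching the `decoder_pred` output. A standalone sketch of that reshape (same permute order as the forward pass):

```python
import torch

bs, k, p, H, W = 2, 1, 16, 224, 224
imgs = torch.randn(bs, k, 3, H, W)                  # ground-truth next frame(s)
h_p, w_p = H // p, W // p
patches = imgs.reshape(bs, k, 3, h_p, p, w_p, p)
patches = patches.permute(0, 1, 3, 5, 4, 6, 2)      # (bs, k, h_p, w_p, p, p, 3)
patches = patches.reshape(bs, k * h_p * w_p, p * p * 3)
print(patches.shape)                                 # torch.Size([2, 196, 768])
```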
203
+ class DETRVAE_MAE(nn.Module):
204
+ """ Variant of DETRVAE that encodes images with a pretrained MAE ViT-Base instead of the ResNet backbone """
205
+ def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
206
+ """ Initializes the model.
207
+ Parameters:
208
+ backbones: torch module of the backbone to be used. See backbone.py
209
+ transformer: torch module of the transformer architecture. See transformer.py
210
+ state_dim: robot state dimension of the environment
211
+ num_queries: number of object queries, ie detection slot. This is the maximal number of objects
212
+ DETR can detect in a single image. For COCO, we recommend 100 queries.
213
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
214
+ """
215
+ super().__init__()
216
+ self.num_queries = num_queries
217
+ self.camera_names = camera_names
218
+ self.transformer = transformer
219
+ self.encoder = encoder
220
+ hidden_dim = transformer.d_model
221
+ self.action_head = nn.Linear(hidden_dim, state_dim)
222
+ self.is_pad_head = nn.Linear(hidden_dim, 1)
223
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
224
+
225
+ # self.model_mae = vits.__dict__['vit_base'](patch_size=16, num_classes=0)
226
+ self.model_mae = vit_base(patch_size=16, num_classes=0)
227
+ mae_ckpt = 'checkpoints/pretrained/mae_pretrain_vit_base.pth'
228
+ checkpoint = torch.load(mae_ckpt, map_location='cpu')
229
+ self.model_mae.load_state_dict(checkpoint['model'], strict=True)
230
+ print('Loaded MAE pretrained weights')
231
+ # for name, p in self.model_mae.named_parameters():
232
+ # p.requires_grad = False
233
+
234
+ if backbones is not None:
235
+ self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
236
+ self.backbones = nn.ModuleList(backbones)
237
+ self.input_proj_robot_state = nn.Linear(14, hidden_dim)
238
+ else:
239
+ # input_dim = 14 + 7 # robot_state + env_state
240
+ self.input_proj_robot_state = nn.Linear(14, hidden_dim)
241
+ self.input_proj_env_state = nn.Linear(7, hidden_dim)
242
+ self.pos = torch.nn.Embedding(2, hidden_dim)
243
+ self.backbones = None
244
+
245
+ # encoder extra parameters
246
+ self.latent_dim = 32 # final size of latent z # TODO tune
247
+ self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
248
+ self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
249
+ self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
250
+ self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
251
+ self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq
252
+
253
+ # decoder extra parameters
254
+ self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
255
+ self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
256
+
257
+ # settings for next frame prediction
258
+ self.patch_size = 16
259
+ self.img_h, self.img_w = 224, 224
260
+ self.n_patch = (self.img_h//self.patch_size)*(self.img_w//self.patch_size)
261
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
262
+ self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.n_patch, hidden_dim), requires_grad=False) # (1, n_patch, h)
263
+ self.patch_embed = nn.Embedding(self.n_patch, hidden_dim)
264
+ self.decoder_embed = nn.Linear(hidden_dim, hidden_dim, bias=True)
265
+
266
+ decoder_depth = 2 # hardcode
267
+ self.decoder_blocks = nn.ModuleList([
268
+ Block(hidden_dim, 16, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm)
269
+ for i in range(decoder_depth)])
270
+
271
+ self.decoder_norm = nn.LayerNorm(hidden_dim)
272
+ self.decoder_pred = nn.Linear(hidden_dim, self.patch_size**2 * 3, bias=True) # decoder to patch
273
+
274
+ decoder_pos_embed = get_2d_sincos_pos_embed_v2(self.decoder_pos_embed.shape[-1], (self.img_h//self.patch_size, self.img_w//self.patch_size))
275
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
276
+
277
+
278
+ def forward(self, qpos, image, env_state, actions=None, is_pad=None):
279
+ """
280
+ qpos: batch, qpos_dim
281
+ image: batch, num_cam, channel, height, width
282
+ env_state: None
283
+ actions: batch, seq, action_dim
284
+ """
285
+ is_training = actions is not None # train or val
286
+ bs, _ = qpos.shape
287
+ ### Obtain latent z from action sequence
288
+ if is_training:
289
+ # project action sequence to embedding dim, and concat with a CLS token
290
+ action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
291
+ qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
292
+ qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
293
+ cls_embed = self.cls_embed.weight # (1, hidden_dim)
294
+ cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
295
+ encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
296
+ encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
297
+ # do not mask cls token
298
+ cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
299
+ is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
300
+ # obtain position embedding
301
+ pos_embed = self.pos_table.clone().detach()
302
+ pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
303
+ # query model
304
+ encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
305
+ encoder_output = encoder_output[0] # take cls output only
306
+ latent_info = self.latent_proj(encoder_output)
307
+ mu = latent_info[:, :self.latent_dim]
308
+ logvar = latent_info[:, self.latent_dim:]
309
+ latent_sample = reparametrize(mu, logvar)
310
+ latent_input = self.latent_out_proj(latent_sample)
311
+ else:
312
+ mu = logvar = None
313
+ latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
314
+ latent_input = self.latent_out_proj(latent_sample)
315
+
316
+ if self.backbones is not None:
317
+ # Image observation features and position embeddings
318
+ all_cam_features = []
319
+ all_cam_pos = []
320
+ if is_training:
321
+ next_frame_images = image[:,1:]
322
+ image = image[:,:1]
323
+ for cam_id, cam_name in enumerate(self.camera_names):
324
+ # features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
325
+ # features = features[0] # take the last layer feature
326
+ # pos = pos[0]
327
+ # all_cam_features.append(self.input_proj(features))
328
+ # all_cam_pos.append(pos)
329
+
330
+ obs_embedings, patch_embedings, pos_mae = self.model_mae(image[:,cam_id])
331
+
332
+ # proprioception features
333
+ proprio_input = self.input_proj_robot_state(qpos)
334
+ # fold camera dimension into width dimension
335
+ # src = torch.cat(all_cam_features, axis=3)
336
+ # pos = torch.cat(all_cam_pos, axis=3)
337
+ query_embed = torch.cat([self.query_embed.weight, self.patch_embed.weight], axis=0)
338
+ hs = self.transformer(patch_embedings, None, query_embed, pos_mae[0,1:], latent_input, proprio_input, self.additional_pos_embed.weight)[0]
339
+ # hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
340
+ else:
341
+ qpos = self.input_proj_robot_state(qpos)
342
+ env_state = self.input_proj_env_state(env_state)
343
+ transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
344
+ hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
345
+ a_hat = self.action_head(hs[:,:self.num_queries])
346
+ is_pad_hat = self.is_pad_head(hs[:,:self.num_queries])
347
+
348
+ # next frame prediction
349
+ mask_token = self.mask_token
350
+ mask_tokens = mask_token.repeat(bs, self.n_patch, 1)
351
+ mask_tokens = mask_tokens + self.decoder_pos_embed.repeat(bs, 1, 1)
352
+
353
+ obs_pred = self.decoder_embed(hs[:,self.num_queries:])
354
+ obs_pred_ = torch.cat([obs_pred, mask_tokens], dim=1)
355
+ for blk in self.decoder_blocks:
356
+ obs_pred_ = blk(obs_pred_)
357
+ obs_pred_ = self.decoder_norm(obs_pred_)
358
+ obs_preds = self.decoder_pred(obs_pred_)
359
+ obs_preds = obs_preds[:,self.n_patch:]
360
+
361
+ if is_training:
362
+ # next_frame_images = image[:,1:]
363
+ # next_frame_images = nn.functional.interpolate(next_frame_images[:,0], size=(self.img_h, self.img_w))
364
+ next_frame_images = next_frame_images[:,0]
365
+ p = self.patch_size
366
+ h_p = self.img_h // p
367
+ w_p = self.img_w // p
368
+ obs_targets = next_frame_images.reshape(shape=(bs, 3, h_p, p, w_p, p))
369
+ obs_targets = obs_targets.permute(0,2,4,3,5,1)
370
+ obs_targets = obs_targets.reshape(shape=(bs, h_p*w_p, (p**2)*3))
371
+ else:
372
+ obs_targets = torch.zeros_like(obs_preds)
373
+
374
+
375
+ return a_hat, is_pad_hat, [mu, logvar], [obs_preds, obs_targets]
376
+
377
+
378
+ class CNNMLP(nn.Module):
379
+ def __init__(self, backbones, state_dim, camera_names):
380
+ """ Initializes the model.
381
+ Parameters:
382
+ backbones: torch module of the backbone to be used. See backbone.py
383
+ transformer: torch module of the transformer architecture. See transformer.py
384
+ state_dim: robot state dimension of the environment
385
+ num_queries: number of object queries, ie detection slot. This is the maximal number of objects
386
+ DETR can detect in a single image. For COCO, we recommend 100 queries.
387
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
388
+ """
389
+ super().__init__()
390
+ self.camera_names = camera_names
391
+ self.action_head = nn.Linear(1000, state_dim) # TODO add more
392
+ if backbones is not None:
393
+ self.backbones = nn.ModuleList(backbones)
394
+ backbone_down_projs = []
395
+ for backbone in backbones:
396
+ down_proj = nn.Sequential(
397
+ nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
398
+ nn.Conv2d(128, 64, kernel_size=5),
399
+ nn.Conv2d(64, 32, kernel_size=5)
400
+ )
401
+ backbone_down_projs.append(down_proj)
402
+ self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
403
+
404
+ mlp_in_dim = 768 * len(backbones) + 14
405
+ self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
406
+ else:
407
+ raise NotImplementedError
408
+
409
+ def forward(self, qpos, image, env_state, actions=None):
410
+ """
411
+ qpos: batch, qpos_dim
412
+ image: batch, num_cam, channel, height, width
413
+ env_state: None
414
+ actions: batch, seq, action_dim
415
+ """
416
+ is_training = actions is not None # train or val
417
+ bs, _ = qpos.shape
418
+ # Image observation features and position embeddings
419
+ all_cam_features = []
420
+ for cam_id, cam_name in enumerate(self.camera_names):
421
+ features, pos = self.backbones[cam_id](image[:, cam_id])
422
+ features = features[0] # take the last layer feature
423
+ pos = pos[0] # not used
424
+ all_cam_features.append(self.backbone_down_projs[cam_id](features))
425
+ # flatten everything
426
+ flattened_features = []
427
+ for cam_feature in all_cam_features:
428
+ flattened_features.append(cam_feature.reshape([bs, -1]))
429
+ flattened_features = torch.cat(flattened_features, axis=1) # 768 each
430
+ features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
431
+ a_hat = self.mlp(features)
432
+ return a_hat
433
+
434
+
435
+ def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
436
+ if hidden_depth == 0:
437
+ mods = [nn.Linear(input_dim, output_dim)]
438
+ else:
439
+ mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
440
+ for i in range(hidden_depth - 1):
441
+ mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
442
+ mods.append(nn.Linear(hidden_dim, output_dim))
443
+ trunk = nn.Sequential(*mods)
444
+ return trunk
445
+
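For example, `hidden_depth=2` (as used in `CNNMLP`) yields Linear → ReLU → Linear → ReLU → Linear:

```python
import torch

net = mlp(input_dim=768 + 14, hidden_dim=1024, output_dim=14, hidden_depth=2)
actions = net(torch.randn(8, 782))   # (batch, mlp_in_dim) -> (batch, 14)
```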
446
+
447
+ def build_encoder(args):
448
+ d_model = args.hidden_dim # 256
449
+ dropout = args.dropout # 0.1
450
+ nhead = args.nheads # 8
451
+ dim_feedforward = args.dim_feedforward # 2048
452
+ num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
453
+ normalize_before = args.pre_norm # False
454
+ activation = "relu"
455
+
456
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
457
+ dropout, activation, normalize_before)
458
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
459
+ encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
460
+
461
+ return encoder
462
+
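A hedged usage sketch of `build_encoder` (the `Namespace` fields simply mirror the attributes read above; the encoder follows DETR's (sequence, batch, hidden) layout and the keyword arguments used in `DETRVAE.forward`):

```python
from argparse import Namespace
import torch

enc_args = Namespace(hidden_dim=256, dropout=0.1, nheads=8,
                     dim_feedforward=2048, enc_layers=4, pre_norm=False)
encoder = build_encoder(enc_args)
src = torch.randn(10, 2, 256)                          # (seq, batch, hidden_dim)
pos = torch.zeros(10, 1, 256)                          # broadcast over batch, as in forward()
pad = torch.zeros(2, 10, dtype=torch.bool)             # (batch, seq); True marks padding
out = encoder(src, pos=pos, src_key_padding_mask=pad)  # (seq, batch, hidden_dim)
```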
463
+
464
+ def build(args):
465
+ state_dim = 14 # TODO hardcode
466
+
467
+ # From state
468
+ # backbone = None # from state for now, no need for conv nets
469
+ # From image
470
+ backbones = []
471
+ backbone = build_backbone(args)
472
+ backbones.append(backbone)
473
+
474
+ transformer = build_transformer(args)
475
+
476
+ encoder = build_encoder(args)
477
+
478
+ if not args.mae:
479
+ model = DETRVAE(
480
+ backbones,
481
+ transformer,
482
+ encoder,
483
+ state_dim=state_dim,
484
+ num_queries=args.num_queries,
485
+ camera_names=args.camera_names,
486
+ )
487
+ else:
488
+ model = DETRVAE_MAE(
489
+ backbones,
490
+ transformer,
491
+ encoder,
492
+ state_dim=state_dim,
493
+ num_queries=args.num_queries,
494
+ camera_names=args.camera_names,
495
+ )
496
+
497
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
498
+ print("number of parameters: %.2fM" % (n_parameters/1e6,))
499
+
500
+ return model
501
+
502
+ def build_cnnmlp(args):
503
+ state_dim = 14 # TODO hardcode
504
+
505
+ # From state
506
+ # backbone = None # from state for now, no need for conv nets
507
+ # From image
508
+ backbones = []
509
+ for _ in args.camera_names:
510
+ backbone = build_backbone(args)
511
+ backbones.append(backbone)
512
+
513
+ model = CNNMLP(
514
+ backbones,
515
+ state_dim=state_dim,
516
+ camera_names=args.camera_names,
517
+ )
518
+
519
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
520
+ print("number of parameters: %.2fM" % (n_parameters/1e6,))
521
+
522
+ return model
523
+
ACT_DP_multitask/detr/models/mask_former/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from . import data # register all new datasets
3
+ from . import modeling
4
+
5
+ # config
6
+ from .config import add_mask_former_config
7
+
8
+ # dataset loading
9
+ from .data.dataset_mappers.detr_panoptic_dataset_mapper import DETRPanopticDatasetMapper
10
+ from .data.dataset_mappers.mask_former_panoptic_dataset_mapper import (
11
+ MaskFormerPanopticDatasetMapper,
12
+ )
13
+ from .data.dataset_mappers.mask_former_semantic_dataset_mapper import (
14
+ MaskFormerSemanticDatasetMapper,
15
+ )
16
+
17
+ # models
18
+ from .mask_former_model import MaskFormer
19
+ from .test_time_augmentation import SemanticSegmentorWithTTA
ACT_DP_multitask/detr/models/mask_former/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (725 Bytes). View file
 
ACT_DP_multitask/detr/models/mask_former/config.py ADDED
@@ -0,0 +1,85 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ from detectron2.config import CfgNode as CN
4
+
5
+
6
+ def add_mask_former_config(cfg):
7
+ """
8
+ Add config for MASK_FORMER.
9
+ """
10
+ # data config
11
+ # select the dataset mapper
12
+ cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
13
+ # Color augmentation
14
+ cfg.INPUT.COLOR_AUG_SSD = False
15
+ # We retry random cropping until no single category in semantic segmentation GT occupies more
16
+ # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
17
+ cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
18
+ # Pad image and segmentation GT in dataset mapper.
19
+ cfg.INPUT.SIZE_DIVISIBILITY = -1
20
+
21
+ # solver config
22
+ # weight decay on embedding
23
+ cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
24
+ # optimizer
25
+ cfg.SOLVER.OPTIMIZER = "ADAMW"
26
+ cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
27
+
28
+ # mask_former model config
29
+ cfg.MODEL.MASK_FORMER = CN()
30
+
31
+ # loss
32
+ cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
33
+ cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1
34
+ cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0
35
+ cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0
36
+
37
+ # transformer config
38
+ cfg.MODEL.MASK_FORMER.NHEADS = 8
39
+ cfg.MODEL.MASK_FORMER.DROPOUT = 0.1
40
+ cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
41
+ cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0
42
+ cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6
43
+ cfg.MODEL.MASK_FORMER.PRE_NORM = False
44
+
45
+ cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
46
+ cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100
47
+
48
+ cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5"
49
+ cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False
50
+
51
+ # mask_former inference config
52
+ cfg.MODEL.MASK_FORMER.TEST = CN()
53
+ cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False
54
+ cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0
55
+ cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0
56
+ cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
57
+
58
+ # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
59
+ # you can use this config to override
60
+ cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32
61
+
62
+ # pixel decoder config
63
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
64
+ # adding transformer in pixel decoder
65
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
66
+ # pixel decoder
67
+ cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder"
68
+
69
+ # swin transformer backbone
70
+ cfg.MODEL.SWIN = CN()
71
+ cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
72
+ cfg.MODEL.SWIN.PATCH_SIZE = 4
73
+ cfg.MODEL.SWIN.EMBED_DIM = 96
74
+ cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
75
+ cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
76
+ cfg.MODEL.SWIN.WINDOW_SIZE = 7
77
+ cfg.MODEL.SWIN.MLP_RATIO = 4.0
78
+ cfg.MODEL.SWIN.QKV_BIAS = True
79
+ cfg.MODEL.SWIN.QK_SCALE = None
80
+ cfg.MODEL.SWIN.DROP_RATE = 0.0
81
+ cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
82
+ cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
83
+ cfg.MODEL.SWIN.APE = False
84
+ cfg.MODEL.SWIN.PATCH_NORM = True
85
+ cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
ACT_DP_multitask/detr/models/mask_former/mask_former_model.py ADDED
@@ -0,0 +1,304 @@
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from detectron2.config import configurable
9
+ from detectron2.data import MetadataCatalog
10
+ from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
11
+ from detectron2.modeling.backbone import Backbone
12
+ from detectron2.modeling.postprocessing import sem_seg_postprocess
13
+ from detectron2.structures import ImageList
14
+
15
+ from .modeling.criterion import SetCriterion
16
+ from .modeling.matcher import HungarianMatcher
17
+
18
+
19
+ @META_ARCH_REGISTRY.register()
20
+ class MaskFormer(nn.Module):
21
+ """
22
+ Main class for mask classification semantic segmentation architectures.
23
+ """
24
+
25
+ @configurable
26
+ def __init__(
27
+ self,
28
+ *,
29
+ backbone: Backbone,
30
+ sem_seg_head: nn.Module,
31
+ criterion: nn.Module,
32
+ num_queries: int,
33
+ panoptic_on: bool,
34
+ object_mask_threshold: float,
35
+ overlap_threshold: float,
36
+ metadata,
37
+ size_divisibility: int,
38
+ sem_seg_postprocess_before_inference: bool,
39
+ pixel_mean: Tuple[float],
40
+ pixel_std: Tuple[float],
41
+ ):
42
+ """
43
+ Args:
44
+ backbone: a backbone module, must follow detectron2's backbone interface
45
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
46
+ criterion: a module that defines the loss
47
+ num_queries: int, number of queries
48
+ panoptic_on: bool, whether to output panoptic segmentation prediction
49
+ object_mask_threshold: float, threshold to filter query based on classification score
50
+ for panoptic segmentation inference
51
+ overlap_threshold: overlap threshold used in general inference for panoptic segmentation
52
+ metadata: dataset meta, get `thing` and `stuff` category names for panoptic
53
+ segmentation inference
54
+ size_divisibility: Some backbones require the input height and width to be divisible by a
55
+ specific integer. We can use this to override such requirement.
56
+ sem_seg_postprocess_before_inference: whether to resize the prediction back
57
+ to original input size before semantic segmentation inference or after.
58
+ For high-resolution dataset like Mapillary, resizing predictions before
59
+ inference will cause OOM error.
60
+ pixel_mean, pixel_std: list or tuple with #channels element, representing
61
+ the per-channel mean and std to be used to normalize the input image
62
+ """
63
+ super().__init__()
64
+ self.backbone = backbone
65
+ self.sem_seg_head = sem_seg_head
66
+ self.criterion = criterion
67
+ self.num_queries = num_queries
68
+ self.overlap_threshold = overlap_threshold
69
+ self.panoptic_on = panoptic_on
70
+ self.object_mask_threshold = object_mask_threshold
71
+ self.metadata = metadata
72
+ if size_divisibility < 0:
73
+ # use backbone size_divisibility if not set
74
+ size_divisibility = self.backbone.size_divisibility
75
+ self.size_divisibility = size_divisibility
76
+ self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
77
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
78
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
79
+
80
+ @classmethod
81
+ def from_config(cls, cfg):
82
+ backbone = build_backbone(cfg)
83
+ sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
84
+
85
+ # Loss parameters:
86
+ deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
87
+ no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
88
+ dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
89
+ mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
90
+
91
+ # building criterion
92
+ matcher = HungarianMatcher(
93
+ cost_class=1,
94
+ cost_mask=mask_weight,
95
+ cost_dice=dice_weight,
96
+ )
97
+
98
+ weight_dict = {"loss_ce": 1, "loss_mask": mask_weight, "loss_dice": dice_weight}
99
+ if deep_supervision:
100
+ dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
101
+ aux_weight_dict = {}
102
+ for i in range(dec_layers - 1):
103
+ aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
104
+ weight_dict.update(aux_weight_dict)
105
+
106
+ losses = ["labels", "masks"]
107
+
108
+ criterion = SetCriterion(
109
+ sem_seg_head.num_classes,
110
+ matcher=matcher,
111
+ weight_dict=weight_dict,
112
+ eos_coef=no_object_weight,
113
+ losses=losses,
114
+ )
115
+
116
+ return {
117
+ "backbone": backbone,
118
+ "sem_seg_head": sem_seg_head,
119
+ "criterion": criterion,
120
+ "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
121
+ "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON,
122
+ "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
123
+ "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
124
+ "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
125
+ "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
126
+ "sem_seg_postprocess_before_inference": (
127
+ cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
128
+ or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON
129
+ ),
130
+ "pixel_mean": cfg.MODEL.PIXEL_MEAN,
131
+ "pixel_std": cfg.MODEL.PIXEL_STD,
132
+ }
133
+
134
+ @property
135
+ def device(self):
136
+ return self.pixel_mean.device
137
+
138
+ def forward(self, batched_inputs):
139
+ """
140
+ Args:
141
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
142
+ Each item in the list contains the inputs for one image.
143
+ For now, each item in the list is a dict that contains:
144
+ * "image": Tensor, image in (C, H, W) format.
145
+ * "instances": per-region ground truth
146
+ * Other information that's included in the original dicts, such as:
147
+ "height", "width" (int): the output resolution of the model (may be different
148
+ from input resolution), used in inference.
149
+ Returns:
150
+ list[dict]:
151
+ each dict has the results for one image. The dict contains the following keys:
152
+
153
+ * "sem_seg":
154
+ A Tensor that represents the
155
+ per-pixel segmentation predicted by the head.
156
+ The prediction has shape KxHxW that represents the logits of
157
+ each class for each pixel.
158
+ * "panoptic_seg":
159
+ A tuple that represents the panoptic output
160
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
161
+ segments_info (list[dict]): Describe each segment in `panoptic_seg`.
162
+ Each dict contains keys "id", "category_id", "isthing".
163
+ """
164
+ images = [x["image"].to(self.device) for x in batched_inputs]
165
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
166
+ images = ImageList.from_tensors(images, self.size_divisibility)
167
+
168
+ features = self.backbone(images.tensor)
169
+ outputs = self.sem_seg_head(features)
170
+
171
+ if self.training:
172
+ # mask classification target
173
+ if "instances" in batched_inputs[0]:
174
+ gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
175
+ targets = self.prepare_targets(gt_instances, images)
176
+ else:
177
+ targets = None
178
+
179
+ # bipartite matching-based loss
180
+ losses = self.criterion(outputs, targets)
181
+
182
+ for k in list(losses.keys()):
183
+ if k in self.criterion.weight_dict:
184
+ losses[k] *= self.criterion.weight_dict[k]
185
+ else:
186
+ # remove this loss if not specified in `weight_dict`
187
+ losses.pop(k)
188
+
189
+ return losses
190
+ else:
191
+ mask_cls_results = outputs["pred_logits"]
192
+ mask_pred_results = outputs["pred_masks"]
193
+ # upsample masks
194
+ mask_pred_results = F.interpolate(
195
+ mask_pred_results,
196
+ size=(images.tensor.shape[-2], images.tensor.shape[-1]),
197
+ mode="bilinear",
198
+ align_corners=False,
199
+ )
200
+
201
+ processed_results = []
202
+ for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
203
+ mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
204
+ ):
205
+ height = input_per_image.get("height", image_size[0])
206
+ width = input_per_image.get("width", image_size[1])
207
+
208
+ if self.sem_seg_postprocess_before_inference:
209
+ mask_pred_result = sem_seg_postprocess(
210
+ mask_pred_result, image_size, height, width
211
+ )
212
+
213
+ # semantic segmentation inference
214
+ r = self.semantic_inference(mask_cls_result, mask_pred_result)
215
+ if not self.sem_seg_postprocess_before_inference:
216
+ r = sem_seg_postprocess(r, image_size, height, width)
217
+ processed_results.append({"sem_seg": r})
218
+
219
+ # panoptic segmentation inference
220
+ if self.panoptic_on:
221
+ panoptic_r = self.panoptic_inference(mask_cls_result, mask_pred_result)
222
+ processed_results[-1]["panoptic_seg"] = panoptic_r
223
+
224
+ return processed_results
225
+
226
+ def prepare_targets(self, targets, images):
227
+ h, w = images.tensor.shape[-2:]
228
+ new_targets = []
229
+ for targets_per_image in targets:
230
+ # pad gt
231
+ gt_masks = targets_per_image.gt_masks
232
+ padded_masks = torch.zeros((gt_masks.shape[0], h, w), dtype=gt_masks.dtype, device=gt_masks.device)
233
+ padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
234
+ new_targets.append(
235
+ {
236
+ "labels": targets_per_image.gt_classes,
237
+ "masks": padded_masks,
238
+ }
239
+ )
240
+ return new_targets
241
+
242
+ def semantic_inference(self, mask_cls, mask_pred):
243
+ mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
244
+ mask_pred = mask_pred.sigmoid()
245
+ semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
246
+ return semseg
247
+
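The einsum in `semantic_inference` marginalizes over the queries: each query contributes its class probabilities (with the trailing "no object" class dropped by `[..., :-1]`) weighted by its sigmoided mask:

```latex
\text{semseg}_{c,h,w} \;=\; \sum_{q}
\operatorname{softmax}(\text{mask\_cls})_{q,c}\,
\sigma(\text{mask\_pred})_{q,h,w}
```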
248
+ def panoptic_inference(self, mask_cls, mask_pred):
249
+ scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
250
+ mask_pred = mask_pred.sigmoid()
251
+
252
+ keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
253
+ cur_scores = scores[keep]
254
+ cur_classes = labels[keep]
255
+ cur_masks = mask_pred[keep]
256
+ cur_mask_cls = mask_cls[keep]
257
+ cur_mask_cls = cur_mask_cls[:, :-1]
258
+
259
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
260
+
261
+ h, w = cur_masks.shape[-2:]
262
+ panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
263
+ segments_info = []
264
+
265
+ current_segment_id = 0
266
+
267
+ if cur_masks.shape[0] == 0:
268
+ # We didn't detect any mask :(
269
+ return panoptic_seg, segments_info
270
+ else:
271
+ # take argmax
272
+ cur_mask_ids = cur_prob_masks.argmax(0)
273
+ stuff_memory_list = {}
274
+ for k in range(cur_classes.shape[0]):
275
+ pred_class = cur_classes[k].item()
276
+ isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
277
+ mask = cur_mask_ids == k
278
+ mask_area = mask.sum().item()
279
+ original_area = (cur_masks[k] >= 0.5).sum().item()
280
+
281
+ if mask_area > 0 and original_area > 0:
282
+ if mask_area / original_area < self.overlap_threshold:
283
+ continue
284
+
285
+ # merge stuff regions
286
+ if not isthing:
287
+ if int(pred_class) in stuff_memory_list.keys():
288
+ panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
289
+ continue
290
+ else:
291
+ stuff_memory_list[int(pred_class)] = current_segment_id + 1
292
+
293
+ current_segment_id += 1
294
+ panoptic_seg[mask] = current_segment_id
295
+
296
+ segments_info.append(
297
+ {
298
+ "id": current_segment_id,
299
+ "isthing": bool(isthing),
300
+ "category_id": int(pred_class),
301
+ }
302
+ )
303
+
304
+ return panoptic_seg, segments_info
ACT_DP_multitask/detr/models/mask_former/modeling/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from .backbone.swin import D2SwinTransformer
3
+ from .heads.mask_former_head import MaskFormerHead
4
+ from .heads.per_pixel_baseline import PerPixelBaselineHead, PerPixelBaselinePlusHead
5
+ from .heads.pixel_decoder import BasePixelDecoder
ACT_DP_multitask/detr/models/mask_former/modeling/backbone/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
ACT_DP_multitask/detr/models/mask_former/modeling/backbone/swin.py ADDED
@@ -0,0 +1,768 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
6
+ # --------------------------------------------------------
7
+
8
+ # Copyright (c) Facebook, Inc. and its affiliates.
9
+ # Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint as checkpoint
16
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
17
+
18
+ from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
19
+
20
+
21
+ class Mlp(nn.Module):
22
+ """Multilayer perceptron."""
23
+
24
+ def __init__(
25
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
26
+ ):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ def window_partition(x, window_size):
45
+ """
46
+ Args:
47
+ x: (B, H, W, C)
48
+ window_size (int): window size
49
+ Returns:
50
+ windows: (num_windows*B, window_size, window_size, C)
51
+ """
52
+ B, H, W, C = x.shape
53
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
54
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
55
+ return windows
56
+
57
+
58
+ def window_reverse(windows, window_size, H, W):
59
+ """
60
+ Args:
61
+ windows: (num_windows*B, window_size, window_size, C)
62
+ window_size (int): Window size
63
+ H (int): Height of image
64
+ W (int): Width of image
65
+ Returns:
66
+ x: (B, H, W, C)
67
+ """
68
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
69
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
70
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
71
+ return x
72
+
73
+
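A quick illustrative shape check: when H and W are multiples of `window_size`, `window_partition` and `window_reverse` are exact inverses.

```python
import torch

x = torch.randn(2, 14, 14, 96)                # (B, H, W, C)
wins = window_partition(x, 7)                 # (2 * 4, 7, 7, 96) = (8, 7, 7, 96)
x_back = window_reverse(wins, 7, 14, 14)      # (2, 14, 14, 96)
assert torch.equal(x, x_back)
```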
74
+ class WindowAttention(nn.Module):
75
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
76
+ It supports both of shifted and non-shifted window.
77
+ Args:
78
+ dim (int): Number of input channels.
79
+ window_size (tuple[int]): The height and width of the window.
80
+ num_heads (int): Number of attention heads.
81
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
82
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
83
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
84
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ dim,
90
+ window_size,
91
+ num_heads,
92
+ qkv_bias=True,
93
+ qk_scale=None,
94
+ attn_drop=0.0,
95
+ proj_drop=0.0,
96
+ ):
97
+
98
+ super().__init__()
99
+ self.dim = dim
100
+ self.window_size = window_size # Wh, Ww
101
+ self.num_heads = num_heads
102
+ head_dim = dim // num_heads
103
+ self.scale = qk_scale or head_dim ** -0.5
104
+
105
+ # define a parameter table of relative position bias
106
+ self.relative_position_bias_table = nn.Parameter(
107
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
108
+ ) # 2*Wh-1 * 2*Ww-1, nH
109
+
110
+ # get pair-wise relative position index for each token inside the window
111
+ coords_h = torch.arange(self.window_size[0])
112
+ coords_w = torch.arange(self.window_size[1])
113
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
114
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
115
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
116
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
117
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
118
+ relative_coords[:, :, 1] += self.window_size[1] - 1
119
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
120
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
121
+ self.register_buffer("relative_position_index", relative_position_index)
122
+
123
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
124
+ self.attn_drop = nn.Dropout(attn_drop)
125
+ self.proj = nn.Linear(dim, dim)
126
+ self.proj_drop = nn.Dropout(proj_drop)
127
+
128
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
129
+ self.softmax = nn.Softmax(dim=-1)
130
+
131
+ def forward(self, x, mask=None):
132
+ """Forward function.
133
+ Args:
134
+ x: input features with shape of (num_windows*B, N, C)
135
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
136
+ """
137
+ B_, N, C = x.shape
138
+ qkv = (
139
+ self.qkv(x)
140
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
141
+ .permute(2, 0, 3, 1, 4)
142
+ )
143
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
144
+
145
+ q = q * self.scale
146
+ attn = q @ k.transpose(-2, -1)
147
+
148
+ relative_position_bias = self.relative_position_bias_table[
149
+ self.relative_position_index.view(-1)
150
+ ].view(
151
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
152
+ ) # Wh*Ww,Wh*Ww,nH
153
+ relative_position_bias = relative_position_bias.permute(
154
+ 2, 0, 1
155
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
156
+ attn = attn + relative_position_bias.unsqueeze(0)
157
+
158
+ if mask is not None:
159
+ nW = mask.shape[0]
160
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
161
+ attn = attn.view(-1, self.num_heads, N, N)
162
+ attn = self.softmax(attn)
163
+ else:
164
+ attn = self.softmax(attn)
165
+
166
+ attn = self.attn_drop(attn)
167
+
168
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
169
+ x = self.proj(x)
170
+ x = self.proj_drop(x)
171
+ return x
172
+
173
+
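`WindowAttention` is scaled dot-product attention computed within each window, plus a learned relative position bias B indexed by the pairwise offsets precomputed in `__init__`:

```latex
\operatorname{Attention}(Q, K, V) =
\operatorname{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_{\text{head}}}} + B\right) V,
\qquad B \in \mathbb{R}^{(W_h W_w) \times (W_h W_w)}
```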
174
+ class SwinTransformerBlock(nn.Module):
175
+ """Swin Transformer Block.
176
+ Args:
177
+ dim (int): Number of input channels.
178
+ num_heads (int): Number of attention heads.
179
+ window_size (int): Window size.
180
+ shift_size (int): Shift size for SW-MSA.
181
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
182
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
183
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
184
+ drop (float, optional): Dropout rate. Default: 0.0
185
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
186
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
187
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
188
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ dim,
194
+ num_heads,
195
+ window_size=7,
196
+ shift_size=0,
197
+ mlp_ratio=4.0,
198
+ qkv_bias=True,
199
+ qk_scale=None,
200
+ drop=0.0,
201
+ attn_drop=0.0,
202
+ drop_path=0.0,
203
+ act_layer=nn.GELU,
204
+ norm_layer=nn.LayerNorm,
205
+ ):
206
+ super().__init__()
207
+ self.dim = dim
208
+ self.num_heads = num_heads
209
+ self.window_size = window_size
210
+ self.shift_size = shift_size
211
+ self.mlp_ratio = mlp_ratio
212
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
213
+
214
+ self.norm1 = norm_layer(dim)
215
+ self.attn = WindowAttention(
216
+ dim,
217
+ window_size=to_2tuple(self.window_size),
218
+ num_heads=num_heads,
219
+ qkv_bias=qkv_bias,
220
+ qk_scale=qk_scale,
221
+ attn_drop=attn_drop,
222
+ proj_drop=drop,
223
+ )
224
+
225
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
226
+ self.norm2 = norm_layer(dim)
227
+ mlp_hidden_dim = int(dim * mlp_ratio)
228
+ self.mlp = Mlp(
229
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
230
+ )
231
+
232
+ self.H = None
233
+ self.W = None
234
+
235
+ def forward(self, x, mask_matrix):
236
+ """Forward function.
237
+ Args:
238
+ x: Input feature, tensor size (B, H*W, C).
239
+ H, W: Spatial resolution of the input feature.
240
+ mask_matrix: Attention mask for cyclic shift.
241
+ """
242
+ B, L, C = x.shape
243
+ H, W = self.H, self.W
244
+ assert L == H * W, "input feature has wrong size"
245
+
246
+ shortcut = x
247
+ x = self.norm1(x)
248
+ x = x.view(B, H, W, C)
249
+
250
+ # pad feature maps to multiples of window size
251
+ pad_l = pad_t = 0
252
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
253
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
254
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
255
+ _, Hp, Wp, _ = x.shape
256
+
257
+ # cyclic shift
258
+ if self.shift_size > 0:
259
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
260
+ attn_mask = mask_matrix
261
+ else:
262
+ shifted_x = x
263
+ attn_mask = None
264
+
265
+ # partition windows
266
+ x_windows = window_partition(
267
+ shifted_x, self.window_size
268
+ ) # nW*B, window_size, window_size, C
269
+ x_windows = x_windows.view(
270
+ -1, self.window_size * self.window_size, C
271
+ ) # nW*B, window_size*window_size, C
272
+
273
+ # W-MSA/SW-MSA
274
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
275
+
276
+ # merge windows
277
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
278
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
279
+
280
+ # reverse cyclic shift
281
+ if self.shift_size > 0:
282
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
283
+ else:
284
+ x = shifted_x
285
+
286
+ if pad_r > 0 or pad_b > 0:
287
+ x = x[:, :H, :W, :].contiguous()
288
+
289
+ x = x.view(B, H * W, C)
290
+
291
+ # FFN
292
+ x = shortcut + self.drop_path(x)
293
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
294
+
295
+ return x
296
+
297
+
298
+ class PatchMerging(nn.Module):
299
+ """Patch Merging Layer
300
+ Args:
301
+ dim (int): Number of input channels.
302
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
303
+ """
304
+
305
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
306
+ super().__init__()
307
+ self.dim = dim
308
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
309
+ self.norm = norm_layer(4 * dim)
310
+
311
+ def forward(self, x, H, W):
312
+ """Forward function.
313
+ Args:
314
+ x: Input feature, tensor size (B, H*W, C).
315
+ H, W: Spatial resolution of the input feature.
316
+ """
317
+ B, L, C = x.shape
318
+ assert L == H * W, "input feature has wrong size"
319
+
320
+ x = x.view(B, H, W, C)
321
+
322
+ # padding
323
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
324
+ if pad_input:
325
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
326
+
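+ # Gather each 2x2 spatial neighbourhood into the channel dimension:
+ # (B, H, W, C) -> (B, H/2 * W/2, 4C), then project 4C -> 2C with self.reduction below.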
327
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
328
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
329
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
330
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
331
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
332
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
333
+
334
+ x = self.norm(x)
335
+ x = self.reduction(x)
336
+
337
+ return x
338
+
339
+
340
+ class BasicLayer(nn.Module):
341
+ """A basic Swin Transformer layer for one stage.
342
+ Args:
343
+ dim (int): Number of feature channels
344
+ depth (int): Depth (number of blocks) of this stage.
345
+ num_heads (int): Number of attention heads.
346
+ window_size (int): Local window size. Default: 7.
347
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
348
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
349
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
350
+ drop (float, optional): Dropout rate. Default: 0.0
351
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
352
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
353
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
354
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
355
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
356
+ """
357
+
358
+ def __init__(
359
+ self,
360
+ dim,
361
+ depth,
362
+ num_heads,
363
+ window_size=7,
364
+ mlp_ratio=4.0,
365
+ qkv_bias=True,
366
+ qk_scale=None,
367
+ drop=0.0,
368
+ attn_drop=0.0,
369
+ drop_path=0.0,
370
+ norm_layer=nn.LayerNorm,
371
+ downsample=None,
372
+ use_checkpoint=False,
373
+ ):
374
+ super().__init__()
375
+ self.window_size = window_size
376
+ self.shift_size = window_size // 2
377
+ self.depth = depth
378
+ self.use_checkpoint = use_checkpoint
379
+
380
+ # build blocks
381
+ self.blocks = nn.ModuleList(
382
+ [
383
+ SwinTransformerBlock(
384
+ dim=dim,
385
+ num_heads=num_heads,
386
+ window_size=window_size,
387
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
388
+ mlp_ratio=mlp_ratio,
389
+ qkv_bias=qkv_bias,
390
+ qk_scale=qk_scale,
391
+ drop=drop,
392
+ attn_drop=attn_drop,
393
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
394
+ norm_layer=norm_layer,
395
+ )
396
+ for i in range(depth)
397
+ ]
398
+ )
399
+
400
+ # patch merging layer
401
+ if downsample is not None:
402
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
403
+ else:
404
+ self.downsample = None
405
+
406
+ def forward(self, x, H, W):
407
+ """Forward function.
408
+ Args:
409
+ x: Input feature, tensor size (B, H*W, C).
410
+ H, W: Spatial resolution of the input feature.
411
+ """
412
+
413
+ # calculate attention mask for SW-MSA
414
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
415
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
416
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
417
+ h_slices = (
418
+ slice(0, -self.window_size),
419
+ slice(-self.window_size, -self.shift_size),
420
+ slice(-self.shift_size, None),
421
+ )
422
+ w_slices = (
423
+ slice(0, -self.window_size),
424
+ slice(-self.window_size, -self.shift_size),
425
+ slice(-self.shift_size, None),
426
+ )
427
+ cnt = 0
428
+ for h in h_slices:
429
+ for w in w_slices:
430
+ img_mask[:, h, w, :] = cnt
431
+ cnt += 1
432
+
433
+ mask_windows = window_partition(
434
+ img_mask, self.window_size
435
+ ) # nW, window_size, window_size, 1
436
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
437
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
438
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
439
+ attn_mask == 0, float(0.0)
440
+ )
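+ # attn_mask is 0 where two positions come from the same shifted region (attention allowed)
+ # and -100.0 where they come from different regions, which drives those weights to ~0 after softmax.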
441
+
442
+ for blk in self.blocks:
443
+ blk.H, blk.W = H, W
444
+ if self.use_checkpoint:
445
+ x = checkpoint.checkpoint(blk, x, attn_mask)
446
+ else:
447
+ x = blk(x, attn_mask)
448
+ if self.downsample is not None:
449
+ x_down = self.downsample(x, H, W)
450
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
451
+ return x, H, W, x_down, Wh, Ww
452
+ else:
453
+ return x, H, W, x, H, W
454
+
455
+
456
+ class PatchEmbed(nn.Module):
457
+ """Image to Patch Embedding
458
+ Args:
459
+ patch_size (int): Patch token size. Default: 4.
460
+ in_chans (int): Number of input image channels. Default: 3.
461
+ embed_dim (int): Number of linear projection output channels. Default: 96.
462
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
463
+ """
464
+
465
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
466
+ super().__init__()
467
+ patch_size = to_2tuple(patch_size)
468
+ self.patch_size = patch_size
469
+
470
+ self.in_chans = in_chans
471
+ self.embed_dim = embed_dim
472
+
473
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
474
+ if norm_layer is not None:
475
+ self.norm = norm_layer(embed_dim)
476
+ else:
477
+ self.norm = None
478
+
479
+ def forward(self, x):
480
+ """Forward function."""
481
+ # padding
482
+ _, _, H, W = x.size()
483
+ if W % self.patch_size[1] != 0:
484
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
485
+ if H % self.patch_size[0] != 0:
486
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
487
+
488
+ x = self.proj(x) # B C Wh Ww
489
+ if self.norm is not None:
490
+ Wh, Ww = x.size(2), x.size(3)
491
+ x = x.flatten(2).transpose(1, 2)
492
+ x = self.norm(x)
493
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
494
+
495
+ return x
496
+
497
+
498
+ class SwinTransformer(nn.Module):
499
+ """Swin Transformer backbone.
500
+ A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
501
+ https://arxiv.org/pdf/2103.14030
502
+ Args:
503
+ pretrain_img_size (int): Input image size for training the pretrained model,
504
+ used in the absolute position embedding. Default 224.
505
+ patch_size (int | tuple(int)): Patch size. Default: 4.
506
+ in_chans (int): Number of input image channels. Default: 3.
507
+ embed_dim (int): Number of linear projection output channels. Default: 96.
508
+ depths (tuple[int]): Depths of each Swin Transformer stage.
509
+ num_heads (tuple[int]): Number of attention heads in each stage.
510
+ window_size (int): Window size. Default: 7.
511
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
512
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
513
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
514
+ drop_rate (float): Dropout rate.
515
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
516
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
517
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
518
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
519
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
520
+ out_indices (Sequence[int]): Output from which stages.
521
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
522
+ -1 means not freezing any parameters.
523
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
524
+ """
525
+
526
+ def __init__(
527
+ self,
528
+ pretrain_img_size=224,
529
+ patch_size=4,
530
+ in_chans=3,
531
+ embed_dim=96,
532
+ depths=[2, 2, 6, 2],
533
+ num_heads=[3, 6, 12, 24],
534
+ window_size=7,
535
+ mlp_ratio=4.0,
536
+ qkv_bias=True,
537
+ qk_scale=None,
538
+ drop_rate=0.0,
539
+ attn_drop_rate=0.0,
540
+ drop_path_rate=0.2,
541
+ norm_layer=nn.LayerNorm,
542
+ ape=False,
543
+ patch_norm=True,
544
+ out_indices=(0, 1, 2, 3),
545
+ frozen_stages=-1,
546
+ use_checkpoint=False,
547
+ ):
548
+ super().__init__()
549
+
550
+ self.pretrain_img_size = pretrain_img_size
551
+ self.num_layers = len(depths)
552
+ self.embed_dim = embed_dim
553
+ self.ape = ape
554
+ self.patch_norm = patch_norm
555
+ self.out_indices = out_indices
556
+ self.frozen_stages = frozen_stages
557
+
558
+ # split image into non-overlapping patches
559
+ self.patch_embed = PatchEmbed(
560
+ patch_size=patch_size,
561
+ in_chans=in_chans,
562
+ embed_dim=embed_dim,
563
+ norm_layer=norm_layer if self.patch_norm else None,
564
+ )
565
+
566
+ # absolute position embedding
567
+ if self.ape:
568
+ pretrain_img_size = to_2tuple(pretrain_img_size)
569
+ patch_size = to_2tuple(patch_size)
570
+ patches_resolution = [
571
+ pretrain_img_size[0] // patch_size[0],
572
+ pretrain_img_size[1] // patch_size[1],
573
+ ]
574
+
575
+ self.absolute_pos_embed = nn.Parameter(
576
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
577
+ )
578
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
579
+
580
+ self.pos_drop = nn.Dropout(p=drop_rate)
581
+
582
+ # stochastic depth
583
+ dpr = [
584
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
585
+ ] # stochastic depth decay rule
586
+
587
+ # build layers
588
+ self.layers = nn.ModuleList()
589
+ for i_layer in range(self.num_layers):
590
+ layer = BasicLayer(
591
+ dim=int(embed_dim * 2 ** i_layer),
592
+ depth=depths[i_layer],
593
+ num_heads=num_heads[i_layer],
594
+ window_size=window_size,
595
+ mlp_ratio=mlp_ratio,
596
+ qkv_bias=qkv_bias,
597
+ qk_scale=qk_scale,
598
+ drop=drop_rate,
599
+ attn_drop=attn_drop_rate,
600
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
601
+ norm_layer=norm_layer,
602
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
603
+ use_checkpoint=use_checkpoint,
604
+ )
605
+ self.layers.append(layer)
606
+
607
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
608
+ self.num_features = num_features
609
+
610
+ # add a norm layer for each output
611
+ for i_layer in out_indices:
612
+ layer = norm_layer(num_features[i_layer])
613
+ layer_name = f"norm{i_layer}"
614
+ self.add_module(layer_name, layer)
615
+
616
+ self._freeze_stages()
617
+
618
+ def _freeze_stages(self):
619
+ if self.frozen_stages >= 0:
620
+ self.patch_embed.eval()
621
+ for param in self.patch_embed.parameters():
622
+ param.requires_grad = False
623
+
624
+ if self.frozen_stages >= 1 and self.ape:
625
+ self.absolute_pos_embed.requires_grad = False
626
+
627
+ if self.frozen_stages >= 2:
628
+ self.pos_drop.eval()
629
+ for i in range(0, self.frozen_stages - 1):
630
+ m = self.layers[i]
631
+ m.eval()
632
+ for param in m.parameters():
633
+ param.requires_grad = False
634
+
635
+ def init_weights(self, pretrained=None):
636
+ """Initialize the weights in backbone.
637
+ Args:
638
+ pretrained (str, optional): Path to pre-trained weights.
639
+ Defaults to None.
640
+ """
641
+
642
+ def _init_weights(m):
643
+ if isinstance(m, nn.Linear):
644
+ trunc_normal_(m.weight, std=0.02)
645
+ if isinstance(m, nn.Linear) and m.bias is not None:
646
+ nn.init.constant_(m.bias, 0)
647
+ elif isinstance(m, nn.LayerNorm):
648
+ nn.init.constant_(m.bias, 0)
649
+ nn.init.constant_(m.weight, 1.0)
650
+
651
+ def forward(self, x):
652
+ """Forward function."""
653
+ x = self.patch_embed(x)
654
+
655
+ Wh, Ww = x.size(2), x.size(3)
656
+ if self.ape:
657
+ # interpolate the position embedding to the corresponding size
658
+ absolute_pos_embed = F.interpolate(
659
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
660
+ )
661
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
662
+ else:
663
+ x = x.flatten(2).transpose(1, 2)
664
+ x = self.pos_drop(x)
665
+
666
+ outs = {}
667
+ for i in range(self.num_layers):
668
+ layer = self.layers[i]
669
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
670
+
671
+ if i in self.out_indices:
672
+ norm_layer = getattr(self, f"norm{i}")
673
+ x_out = norm_layer(x_out)
674
+
675
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
676
+ outs["res{}".format(i + 2)] = out
677
+
678
+ return outs
679
+
680
+ def train(self, mode=True):
681
+ """Convert the model into training mode while keep layers freezed."""
682
+ super(SwinTransformer, self).train(mode)
683
+ self._freeze_stages()
684
+
685
+
686
+ @BACKBONE_REGISTRY.register()
687
+ class D2SwinTransformer(SwinTransformer, Backbone):
688
+ def __init__(self, cfg, input_shape):
689
+
690
+ pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
691
+ patch_size = cfg.MODEL.SWIN.PATCH_SIZE
692
+ in_chans = 3
693
+ embed_dim = cfg.MODEL.SWIN.EMBED_DIM
694
+ depths = cfg.MODEL.SWIN.DEPTHS
695
+ num_heads = cfg.MODEL.SWIN.NUM_HEADS
696
+ window_size = cfg.MODEL.SWIN.WINDOW_SIZE
697
+ mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
698
+ qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
699
+ qk_scale = cfg.MODEL.SWIN.QK_SCALE
700
+ drop_rate = cfg.MODEL.SWIN.DROP_RATE
701
+ attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
702
+ drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
703
+ norm_layer = nn.LayerNorm
704
+ ape = cfg.MODEL.SWIN.APE
705
+ patch_norm = cfg.MODEL.SWIN.PATCH_NORM
706
+
707
+ super().__init__(
708
+ pretrain_img_size,
709
+ patch_size,
710
+ in_chans,
711
+ embed_dim,
712
+ depths,
713
+ num_heads,
714
+ window_size,
715
+ mlp_ratio,
716
+ qkv_bias,
717
+ qk_scale,
718
+ drop_rate,
719
+ attn_drop_rate,
720
+ drop_path_rate,
721
+ norm_layer,
722
+ ape,
723
+ patch_norm,
724
+ )
725
+
726
+ self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
727
+
728
+ self._out_feature_strides = {
729
+ "res2": 4,
730
+ "res3": 8,
731
+ "res4": 16,
732
+ "res5": 32,
733
+ }
734
+ self._out_feature_channels = {
735
+ "res2": self.num_features[0],
736
+ "res3": self.num_features[1],
737
+ "res4": self.num_features[2],
738
+ "res5": self.num_features[3],
739
+ }
740
+
741
+ def forward(self, x):
742
+ """
743
+ Args:
744
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
745
+ Returns:
746
+ dict[str->Tensor]: names and the corresponding features
747
+ """
748
+ assert (
749
+ x.dim() == 4
750
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
751
+ outputs = {}
752
+ y = super().forward(x)
753
+ for k in y.keys():
754
+ if k in self._out_features:
755
+ outputs[k] = y[k]
756
+ return outputs
757
+
758
+ def output_shape(self):
759
+ return {
760
+ name: ShapeSpec(
761
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
762
+ )
763
+ for name in self._out_features
764
+ }
765
+
766
+ @property
767
+ def size_divisibility(self):
768
+ return 32
ACT_DP_multitask/detr/models/mask_former/modeling/criterion.py ADDED
@@ -0,0 +1,187 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
3
+ """
4
+ MaskFormer criterion.
5
+ """
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from detectron2.utils.comm import get_world_size
11
+
12
+ from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list
13
+
14
+
15
+ def dice_loss(inputs, targets, num_masks):
16
+ """
17
+ Compute the DICE loss, similar to generalized IOU for masks
18
+ Args:
19
+ inputs: A float tensor of arbitrary shape.
20
+ The predictions for each example.
21
+ targets: A float tensor with the same shape as inputs. Stores the binary
22
+ classification label for each element in inputs
23
+ (0 for the negative class and 1 for the positive class).
24
+ """
25
+ inputs = inputs.sigmoid()
26
+ inputs = inputs.flatten(1)
27
+ numerator = 2 * (inputs * targets).sum(-1)
28
+ denominator = inputs.sum(-1) + targets.sum(-1)
29
+ loss = 1 - (numerator + 1) / (denominator + 1)
30
+ return loss.sum() / num_masks
31
+
32
+
33
+ def sigmoid_focal_loss(inputs, targets, num_masks, alpha: float = 0.25, gamma: float = 2):
34
+ """
35
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
36
+ Args:
37
+ inputs: A float tensor of arbitrary shape.
38
+ The predictions for each example.
39
+ targets: A float tensor with the same shape as inputs. Stores the binary
40
+ classification label for each element in inputs
41
+ (0 for the negative class and 1 for the positive class).
42
+ alpha: (optional) Weighting factor in range (0,1) to balance
43
+ positive vs negative examples. Default = 0.25.
44
+ gamma: Exponent of the modulating factor (1 - p_t) to
45
+ balance easy vs hard examples.
46
+ Returns:
47
+ Loss tensor
48
+ """
49
+ prob = inputs.sigmoid()
50
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
51
+ p_t = prob * targets + (1 - prob) * (1 - targets)
52
+ loss = ce_loss * ((1 - p_t) ** gamma)
53
+
54
+ if alpha >= 0:
55
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
56
+ loss = alpha_t * loss
57
+
58
+ return loss.mean(1).sum() / num_masks
59
+
60
+
61
+ class SetCriterion(nn.Module):
62
+ """This class computes the loss for DETR.
63
+ The process happens in two steps:
64
+ 1) we compute a Hungarian assignment between the ground-truth masks and the outputs of the model
65
+ 2) we supervise each pair of matched ground truth / prediction (supervise class and mask)
66
+ """
67
+
68
+ def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
69
+ """Create the criterion.
70
+ Parameters:
71
+ num_classes: number of object categories, omitting the special no-object category
72
+ matcher: module able to compute a matching between targets and proposals
73
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
74
+ eos_coef: relative classification weight applied to the no-object category
75
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
76
+ """
77
+ super().__init__()
78
+ self.num_classes = num_classes
79
+ self.matcher = matcher
80
+ self.weight_dict = weight_dict
81
+ self.eos_coef = eos_coef
82
+ self.losses = losses
83
+ empty_weight = torch.ones(self.num_classes + 1)
84
+ empty_weight[-1] = self.eos_coef
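+ # The last index is the no-object class; eos_coef down-weights it in the classification loss.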
85
+ self.register_buffer("empty_weight", empty_weight)
86
+
87
+ def loss_labels(self, outputs, targets, indices, num_masks):
88
+ """Classification loss (NLL)
89
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
90
+ """
91
+ assert "pred_logits" in outputs
92
+ src_logits = outputs["pred_logits"]
93
+
94
+ idx = self._get_src_permutation_idx(indices)
95
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
96
+ target_classes = torch.full(
97
+ src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
98
+ )
99
+ target_classes[idx] = target_classes_o
100
+
101
+ loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
102
+ losses = {"loss_ce": loss_ce}
103
+ return losses
104
+
105
+ def loss_masks(self, outputs, targets, indices, num_masks):
106
+ """Compute the losses related to the masks: the focal loss and the dice loss.
107
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
108
+ """
109
+ assert "pred_masks" in outputs
110
+
111
+ src_idx = self._get_src_permutation_idx(indices)
112
+ tgt_idx = self._get_tgt_permutation_idx(indices)
113
+ src_masks = outputs["pred_masks"]
114
+ src_masks = src_masks[src_idx]
115
+ masks = [t["masks"] for t in targets]
116
+ # TODO use valid to mask invalid areas due to padding in loss
117
+ target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
118
+ target_masks = target_masks.to(src_masks)
119
+ target_masks = target_masks[tgt_idx]
120
+
121
+ # upsample predictions to the target size
122
+ src_masks = F.interpolate(
123
+ src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
124
+ )
125
+ src_masks = src_masks[:, 0].flatten(1)
126
+
127
+ target_masks = target_masks.flatten(1)
128
+ target_masks = target_masks.view(src_masks.shape)
129
+ losses = {
130
+ "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks),
131
+ "loss_dice": dice_loss(src_masks, target_masks, num_masks),
132
+ }
133
+ return losses
134
+
135
+ def _get_src_permutation_idx(self, indices):
136
+ # permute predictions following indices
137
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
138
+ src_idx = torch.cat([src for (src, _) in indices])
139
+ return batch_idx, src_idx
140
+
141
+ def _get_tgt_permutation_idx(self, indices):
142
+ # permute targets following indices
143
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
144
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
145
+ return batch_idx, tgt_idx
146
+
147
+ def get_loss(self, loss, outputs, targets, indices, num_masks):
148
+ loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
149
+ assert loss in loss_map, f"do you really want to compute {loss} loss?"
150
+ return loss_map[loss](outputs, targets, indices, num_masks)
151
+
152
+ def forward(self, outputs, targets):
153
+ """This performs the loss computation.
154
+ Parameters:
155
+ outputs: dict of tensors, see the output specification of the model for the format
156
+ targets: list of dicts, such that len(targets) == batch_size.
157
+ The expected keys in each dict depend on the losses applied; see each loss's doc
158
+ """
159
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
160
+
161
+ # Retrieve the matching between the outputs of the last layer and the targets
162
+ indices = self.matcher(outputs_without_aux, targets)
163
+
164
+ # Compute the average number of target masks across all nodes, for normalization purposes
165
+ num_masks = sum(len(t["labels"]) for t in targets)
166
+ num_masks = torch.as_tensor(
167
+ [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
168
+ )
169
+ if is_dist_avail_and_initialized():
170
+ torch.distributed.all_reduce(num_masks)
171
+ num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
172
+
173
+ # Compute all the requested losses
174
+ losses = {}
175
+ for loss in self.losses:
176
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))
177
+
178
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
179
+ if "aux_outputs" in outputs:
180
+ for i, aux_outputs in enumerate(outputs["aux_outputs"]):
181
+ indices = self.matcher(aux_outputs, targets)
182
+ for loss in self.losses:
183
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)
184
+ l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
185
+ losses.update(l_dict)
186
+
187
+ return losses
ACT_DP_multitask/detr/models/mask_former/modeling/heads/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
ACT_DP_multitask/detr/models/mask_former/modeling/heads/mask_former_head.py ADDED
@@ -0,0 +1,119 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ from copy import deepcopy
4
+ from typing import Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import fvcore.nn.weight_init as weight_init
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from detectron2.config import configurable
11
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
12
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
13
+
14
+ from ..transformer.transformer_predictor import TransformerPredictor
15
+ from .pixel_decoder import build_pixel_decoder
16
+
17
+
18
+ @SEM_SEG_HEADS_REGISTRY.register()
19
+ class MaskFormerHead(nn.Module):
20
+
21
+ _version = 2
22
+
23
+ def _load_from_state_dict(
24
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
25
+ ):
26
+ version = local_metadata.get("version", None)
27
+ if version is None or version < 2:
28
+ # Do not warn if train from scratch
29
+ scratch = True
30
+ logger = logging.getLogger(__name__)
31
+ for k in list(state_dict.keys()):
32
+ newk = k
33
+ if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
34
+ newk = k.replace(prefix, prefix + "pixel_decoder.")
35
+ # logger.debug(f"{k} ==> {newk}")
36
+ if newk != k:
37
+ state_dict[newk] = state_dict[k]
38
+ del state_dict[k]
39
+ scratch = False
40
+
41
+ if not scratch:
42
+ logger.warning(
43
+ f"Weight format of {self.__class__.__name__} have changed! "
44
+ "Please upgrade your models. Applying automatic conversion now ..."
45
+ )
46
+
47
+ @configurable
48
+ def __init__(
49
+ self,
50
+ input_shape: Dict[str, ShapeSpec],
51
+ *,
52
+ num_classes: int,
53
+ pixel_decoder: nn.Module,
54
+ loss_weight: float = 1.0,
55
+ ignore_value: int = -1,
56
+ # extra parameters
57
+ transformer_predictor: nn.Module,
58
+ transformer_in_feature: str,
59
+ ):
60
+ """
61
+ NOTE: this interface is experimental.
62
+ Args:
63
+ input_shape: shapes (channels and stride) of the input features
64
+ num_classes: number of classes to predict
65
+ pixel_decoder: the pixel decoder module
66
+ loss_weight: loss weight
67
+ ignore_value: category id to be ignored during training.
68
+ transformer_predictor: the transformer decoder that makes prediction
69
+ transformer_in_feature: input feature name to the transformer_predictor
70
+ """
71
+ super().__init__()
72
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
73
+ self.in_features = [k for k, v in input_shape]
74
+ feature_strides = [v.stride for k, v in input_shape]
75
+ feature_channels = [v.channels for k, v in input_shape]
76
+
77
+ self.ignore_value = ignore_value
78
+ self.common_stride = 4
79
+ self.loss_weight = loss_weight
80
+
81
+ self.pixel_decoder = pixel_decoder
82
+ self.predictor = transformer_predictor
83
+ self.transformer_in_feature = transformer_in_feature
84
+
85
+ self.num_classes = num_classes
86
+
87
+ @classmethod
88
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
89
+ return {
90
+ "input_shape": {
91
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
92
+ },
93
+ "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
94
+ "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
95
+ "pixel_decoder": build_pixel_decoder(cfg, input_shape),
96
+ "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
97
+ "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
98
+ "transformer_predictor": TransformerPredictor(
99
+ cfg,
100
+ cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
101
+ if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder"
102
+ else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels,
103
+ mask_classification=True,
104
+ ),
105
+ }
106
+
107
+ def forward(self, features):
108
+ return self.layers(features)
109
+
110
+ def layers(self, features):
111
+ mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features)
112
+ if self.transformer_in_feature == "transformer_encoder":
113
+ assert (
114
+ transformer_encoder_features is not None
115
+ ), "Please use the TransformerEncoderPixelDecoder."
116
+ predictions = self.predictor(transformer_encoder_features, mask_features)
117
+ else:
118
+ predictions = self.predictor(features[self.transformer_in_feature], mask_features)
119
+ return predictions
ACT_DP_multitask/detr/models/mask_former/modeling/heads/per_pixel_baseline.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ from typing import Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import fvcore.nn.weight_init as weight_init
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
11
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
12
+
13
+ from ..transformer.transformer_predictor import TransformerPredictor
14
+ from .pixel_decoder import build_pixel_decoder
15
+
16
+
17
+ @SEM_SEG_HEADS_REGISTRY.register()
18
+ class PerPixelBaselineHead(nn.Module):
19
+
20
+ _version = 2
21
+
22
+ def _load_from_state_dict(
23
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
24
+ ):
25
+ version = local_metadata.get("version", None)
26
+ if version is None or version < 2:
27
+ logger = logging.getLogger(__name__)
28
+ # Do not warn if train from scratch
29
+ scratch = True
30
+ logger = logging.getLogger(__name__)
31
+ for k in list(state_dict.keys()):
32
+ newk = k
33
+ if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
34
+ newk = k.replace(prefix, prefix + "pixel_decoder.")
35
+ # logger.warning(f"{k} ==> {newk}")
36
+ if newk != k:
37
+ state_dict[newk] = state_dict[k]
38
+ del state_dict[k]
39
+ scratch = False
40
+
41
+ if not scratch:
42
+ logger.warning(
43
+ f"Weight format of {self.__class__.__name__} have changed! "
44
+ "Please upgrade your models. Applying automatic conversion now ..."
45
+ )
46
+
47
+ @configurable
48
+ def __init__(
49
+ self,
50
+ input_shape: Dict[str, ShapeSpec],
51
+ *,
52
+ num_classes: int,
53
+ pixel_decoder: nn.Module,
54
+ loss_weight: float = 1.0,
55
+ ignore_value: int = -1,
56
+ ):
57
+ """
58
+ NOTE: this interface is experimental.
59
+ Args:
60
+ input_shape: shapes (channels and stride) of the input features
61
+ num_classes: number of classes to predict
62
+ pixel_decoder: the pixel decoder module
63
+ loss_weight: loss weight
64
+ ignore_value: category id to be ignored during training.
65
+ """
66
+ super().__init__()
67
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
68
+ self.in_features = [k for k, v in input_shape]
69
+ feature_strides = [v.stride for k, v in input_shape]
70
+ feature_channels = [v.channels for k, v in input_shape]
71
+
72
+ self.ignore_value = ignore_value
73
+ self.common_stride = 4
74
+ self.loss_weight = loss_weight
75
+
76
+ self.pixel_decoder = pixel_decoder
77
+ self.predictor = Conv2d(
78
+ self.pixel_decoder.mask_dim, num_classes, kernel_size=1, stride=1, padding=0
79
+ )
80
+ weight_init.c2_msra_fill(self.predictor)
81
+
82
+ @classmethod
83
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
84
+ return {
85
+ "input_shape": {
86
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
87
+ },
88
+ "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
89
+ "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
90
+ "pixel_decoder": build_pixel_decoder(cfg, input_shape),
91
+ "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
92
+ }
93
+
94
+ def forward(self, features, targets=None):
95
+ """
96
+ Returns:
97
+ In training, returns (None, dict of losses)
98
+ In inference, returns (CxHxW logits, {})
99
+ """
100
+ x = self.layers(features)
101
+ if self.training:
102
+ return None, self.losses(x, targets)
103
+ else:
104
+ x = F.interpolate(
105
+ x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
106
+ )
107
+ return x, {}
108
+
109
+ def layers(self, features):
110
+ x, _ = self.pixel_decoder.forward_features(features)
111
+ x = self.predictor(x)
112
+ return x
113
+
114
+ def losses(self, predictions, targets):
115
+ predictions = predictions.float() # https://github.com/pytorch/pytorch/issues/48163
116
+ predictions = F.interpolate(
117
+ predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
118
+ )
119
+ loss = F.cross_entropy(
120
+ predictions, targets, reduction="mean", ignore_index=self.ignore_value
121
+ )
122
+ losses = {"loss_sem_seg": loss * self.loss_weight}
123
+ return losses
124
+
125
+
126
+ @SEM_SEG_HEADS_REGISTRY.register()
127
+ class PerPixelBaselinePlusHead(PerPixelBaselineHead):
128
+ def _load_from_state_dict(
129
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
130
+ ):
131
+ version = local_metadata.get("version", None)
132
+ if version is None or version < 2:
133
+ # Do not warn if train from scratch
134
+ scratch = True
135
+ logger = logging.getLogger(__name__)
136
+ for k in list(state_dict.keys()):
137
+ newk = k
138
+ if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
139
+ newk = k.replace(prefix, prefix + "pixel_decoder.")
140
+ logger.debug(f"{k} ==> {newk}")
141
+ if newk != k:
142
+ state_dict[newk] = state_dict[k]
143
+ del state_dict[k]
144
+ scratch = False
145
+
146
+ if not scratch:
147
+ logger.warning(
148
+ f"Weight format of {self.__class__.__name__} have changed! "
149
+ "Please upgrade your models. Applying automatic conversion now ..."
150
+ )
151
+
152
+ @configurable
153
+ def __init__(
154
+ self,
155
+ input_shape: Dict[str, ShapeSpec],
156
+ *,
157
+ # extra parameters
158
+ transformer_predictor: nn.Module,
159
+ transformer_in_feature: str,
160
+ deep_supervision: bool,
161
+ # inherit parameters
162
+ num_classes: int,
163
+ pixel_decoder: nn.Module,
164
+ loss_weight: float = 1.0,
165
+ ignore_value: int = -1,
166
+ ):
167
+ """
168
+ NOTE: this interface is experimental.
169
+ Args:
170
+ input_shape: shapes (channels and stride) of the input features
171
+ transformer_predictor: the transformer decoder that makes prediction
172
+ transformer_in_feature: input feature name to the transformer_predictor
173
+ deep_supervision: whether or not to add supervision to the output of
174
+ every transformer decoder layer
175
+ num_classes: number of classes to predict
176
+ pixel_decoder: the pixel decoder module
177
+ loss_weight: loss weight
178
+ ignore_value: category id to be ignored during training.
179
+ """
180
+ super().__init__(
181
+ input_shape,
182
+ num_classes=num_classes,
183
+ pixel_decoder=pixel_decoder,
184
+ loss_weight=loss_weight,
185
+ ignore_value=ignore_value,
186
+ )
187
+
188
+ del self.predictor
189
+
190
+ self.predictor = transformer_predictor
191
+ self.transformer_in_feature = transformer_in_feature
192
+ self.deep_supervision = deep_supervision
193
+
194
+ @classmethod
195
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
196
+ ret = super().from_config(cfg, input_shape)
197
+ ret["transformer_in_feature"] = cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE
198
+ if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder":
199
+ in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
200
+ else:
201
+ in_channels = input_shape[ret["transformer_in_feature"]].channels
202
+ ret["transformer_predictor"] = TransformerPredictor(
203
+ cfg, in_channels, mask_classification=False
204
+ )
205
+ ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
206
+ return ret
207
+
208
+ def forward(self, features, targets=None):
209
+ """
210
+ Returns:
211
+ In training, returns (None, dict of losses)
212
+ In inference, returns (CxHxW logits, {})
213
+ """
214
+ x, aux_outputs = self.layers(features)
215
+ if self.training:
216
+ if self.deep_supervision:
217
+ losses = self.losses(x, targets)
218
+ for i, aux_output in enumerate(aux_outputs):
219
+ losses["loss_sem_seg" + f"_{i}"] = self.losses(
220
+ aux_output["pred_masks"], targets
221
+ )["loss_sem_seg"]
222
+ return None, losses
223
+ else:
224
+ return None, self.losses(x, targets)
225
+ else:
226
+ x = F.interpolate(
227
+ x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
228
+ )
229
+ return x, {}
230
+
231
+ def layers(self, features):
232
+ mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features)
233
+ if self.transformer_in_feature == "transformer_encoder":
234
+ assert (
235
+ transformer_encoder_features is not None
236
+ ), "Please use the TransformerEncoderPixelDecoder."
237
+ predictions = self.predictor(transformer_encoder_features, mask_features)
238
+ else:
239
+ predictions = self.predictor(features[self.transformer_in_feature], mask_features)
240
+ if self.deep_supervision:
241
+ return predictions["pred_masks"], predictions["aux_outputs"]
242
+ else:
243
+ return predictions["pred_masks"], None
ACT_DP_multitask/detr/models/mask_former/modeling/heads/pixel_decoder.py ADDED
@@ -0,0 +1,294 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import logging
3
+ from typing import Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import fvcore.nn.weight_init as weight_init
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.config import configurable
10
+ from detectron2.layers import Conv2d, ShapeSpec, get_norm
11
+ from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
12
+
13
+ from ..transformer.position_encoding import PositionEmbeddingSine
14
+ from ..transformer.transformer import TransformerEncoder, TransformerEncoderLayer
15
+
16
+
17
+ def build_pixel_decoder(cfg, input_shape):
18
+ """
19
+ Build a pixel decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`.
20
+ """
21
+ name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME
22
+ model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
23
+ forward_features = getattr(model, "forward_features", None)
24
+ if not callable(forward_features):
25
+ raise ValueError(
26
+ "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
27
+ f"Please implement forward_features for {name} to only return mask features."
28
+ )
29
+ return model
30
+
31
+
32
+ @SEM_SEG_HEADS_REGISTRY.register()
33
+ class BasePixelDecoder(nn.Module):
34
+ # @configurable
35
+ def __init__(
36
+ self,
37
+ input_shape: Dict[str, ShapeSpec],
38
+ # *,
39
+ conv_dim: int,
40
+ mask_dim: int,
41
+ norm: Optional[Union[str, Callable]] = None,
42
+ ):
43
+ """
44
+ NOTE: this interface is experimental.
45
+ Args:
46
+ input_shape: shapes (channels and stride) of the input features
47
+ conv_dims: number of output channels for the intermediate conv layers.
48
+ mask_dim: number of output channels for the final conv layer.
49
+ norm (str or callable): normalization for all conv layers
50
+ """
51
+ super().__init__()
52
+
53
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
54
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
55
+ feature_channels = [v.channels for k, v in input_shape]
56
+
57
+ lateral_convs = []
58
+ output_convs = []
59
+
60
+ use_bias = norm == ""
61
+ for idx, in_channels in enumerate(feature_channels):
62
+ if idx == len(self.in_features) - 1:
63
+ output_norm = get_norm(norm, conv_dim)
64
+ output_conv = Conv2d(
65
+ in_channels,
66
+ conv_dim,
67
+ kernel_size=3,
68
+ stride=1,
69
+ padding=1,
70
+ bias=use_bias,
71
+ norm=output_norm,
72
+ activation=F.relu,
73
+ )
74
+ weight_init.c2_xavier_fill(output_conv)
75
+ self.add_module("layer_{}".format(idx + 1), output_conv)
76
+
77
+ lateral_convs.append(None)
78
+ output_convs.append(output_conv)
79
+ else:
80
+ lateral_norm = get_norm(norm, conv_dim)
81
+ output_norm = get_norm(norm, conv_dim)
82
+
83
+ lateral_conv = Conv2d(
84
+ in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
85
+ )
86
+ output_conv = Conv2d(
87
+ conv_dim,
88
+ conv_dim,
89
+ kernel_size=3,
90
+ stride=1,
91
+ padding=1,
92
+ bias=use_bias,
93
+ norm=output_norm,
94
+ activation=F.relu,
95
+ )
96
+ weight_init.c2_xavier_fill(lateral_conv)
97
+ weight_init.c2_xavier_fill(output_conv)
98
+ self.add_module("adapter_{}".format(idx + 1), lateral_conv)
99
+ self.add_module("layer_{}".format(idx + 1), output_conv)
100
+
101
+ lateral_convs.append(lateral_conv)
102
+ output_convs.append(output_conv)
103
+ # Place convs into top-down order (from low to high resolution)
104
+ # to make the top-down computation in forward clearer.
105
+ self.lateral_convs = lateral_convs[::-1]
106
+ self.output_convs = output_convs[::-1]
107
+
108
+ self.mask_dim = mask_dim
109
+ self.mask_features = Conv2d(
110
+ conv_dim,
111
+ mask_dim,
112
+ kernel_size=3,
113
+ stride=1,
114
+ padding=1,
115
+ )
116
+ weight_init.c2_xavier_fill(self.mask_features)
117
+
118
+ # @classmethod
119
+ # def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
120
+ # ret = {}
121
+ # ret["input_shape"] = {
122
+ # k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
123
+ # }
124
+ # ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
125
+ # ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
126
+ # ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
127
+ # return ret
128
+
129
+ def forward_features(self, features):
130
+ # Reverse feature maps into top-down order (from low to high resolution)
131
+ for idx, f in enumerate(self.in_features[::-1]):
132
+ x = features[f]
133
+ lateral_conv = self.lateral_convs[idx]
134
+ output_conv = self.output_convs[idx]
135
+ if lateral_conv is None:
136
+ y = output_conv(x)
137
+ else:
138
+ cur_fpn = lateral_conv(x)
139
+ # Following FPN implementation, we use nearest upsampling here
140
+ y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
141
+ y = output_conv(y)
142
+ return self.mask_features(y), None
143
+
144
+ def forward(self, features, targets=None):
145
+ logger = logging.getLogger(__name__)
146
+ logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
147
+ return self.forward_features(features)
148
+
149
+
150
+ class TransformerEncoderOnly(nn.Module):
151
+ def __init__(
152
+ self,
153
+ d_model=512,
154
+ nhead=8,
155
+ num_encoder_layers=6,
156
+ dim_feedforward=2048,
157
+ dropout=0.1,
158
+ activation="relu",
159
+ normalize_before=False,
160
+ ):
161
+ super().__init__()
162
+
163
+ encoder_layer = TransformerEncoderLayer(
164
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
165
+ )
166
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
167
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
168
+
169
+ self._reset_parameters()
170
+
171
+ self.d_model = d_model
172
+ self.nhead = nhead
173
+
174
+ def _reset_parameters(self):
175
+ for p in self.parameters():
176
+ if p.dim() > 1:
177
+ nn.init.xavier_uniform_(p)
178
+
179
+ def forward(self, src, mask, pos_embed):
180
+ # flatten NxCxHxW to HWxNxC
181
+ bs, c, h, w = src.shape
182
+ src = src.flatten(2).permute(2, 0, 1)
183
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
184
+ if mask is not None:
185
+ mask = mask.flatten(1)
186
+
187
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
188
+ return memory.permute(1, 2, 0).view(bs, c, h, w)
189
+
190
+
191
+ @SEM_SEG_HEADS_REGISTRY.register()
192
+ class TransformerEncoderPixelDecoder(BasePixelDecoder):
193
+ @configurable
194
+ def __init__(
195
+ self,
196
+ input_shape: Dict[str, ShapeSpec],
197
+ *,
198
+ transformer_dropout: float,
199
+ transformer_nheads: int,
200
+ transformer_dim_feedforward: int,
201
+ transformer_enc_layers: int,
202
+ transformer_pre_norm: bool,
203
+ conv_dim: int,
204
+ mask_dim: int,
205
+ norm: Optional[Union[str, Callable]] = None,
206
+ ):
207
+ """
208
+ NOTE: this interface is experimental.
209
+ Args:
210
+ input_shape: shapes (channels and stride) of the input features
211
+ transformer_dropout: dropout probability in transformer
212
+ transformer_nheads: number of heads in transformer
213
+ transformer_dim_feedforward: dimension of feedforward network
214
+ transformer_enc_layers: number of transformer encoder layers
215
+ transformer_pre_norm: whether to use pre-layernorm or not
216
+ conv_dims: number of output channels for the intermediate conv layers.
217
+ mask_dim: number of output channels for the final conv layer.
218
+ norm (str or callable): normalization for all conv layers
219
+ """
220
+ super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm)
221
+
222
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
223
+ self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5"
224
+ feature_strides = [v.stride for k, v in input_shape]
225
+ feature_channels = [v.channels for k, v in input_shape]
226
+
227
+ in_channels = feature_channels[len(self.in_features) - 1]
228
+ self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
229
+ weight_init.c2_xavier_fill(self.input_proj)
230
+ self.transformer = TransformerEncoderOnly(
231
+ d_model=conv_dim,
232
+ dropout=transformer_dropout,
233
+ nhead=transformer_nheads,
234
+ dim_feedforward=transformer_dim_feedforward,
235
+ num_encoder_layers=transformer_enc_layers,
236
+ normalize_before=transformer_pre_norm,
237
+ )
238
+ N_steps = conv_dim // 2
239
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
240
+
241
+ # update layer
242
+ use_bias = norm == ""
243
+ output_norm = get_norm(norm, conv_dim)
244
+ output_conv = Conv2d(
245
+ conv_dim,
246
+ conv_dim,
247
+ kernel_size=3,
248
+ stride=1,
249
+ padding=1,
250
+ bias=use_bias,
251
+ norm=output_norm,
252
+ activation=F.relu,
253
+ )
254
+ weight_init.c2_xavier_fill(output_conv)
255
+ delattr(self, "layer_{}".format(len(self.in_features)))
256
+ self.add_module("layer_{}".format(len(self.in_features)), output_conv)
257
+ self.output_convs[0] = output_conv
258
+
259
+ @classmethod
260
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
261
+ ret = super().from_config(cfg, input_shape)
262
+ ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
263
+ ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
264
+ ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
265
+ ret[
266
+ "transformer_enc_layers"
267
+ ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS # a separate config
268
+ ret["transformer_pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
269
+ return ret
270
+
271
+ def forward_features(self, features):
272
+ # Reverse feature maps into top-down order (from low to high resolution)
273
+ for idx, f in enumerate(self.in_features[::-1]):
274
+ x = features[f]
275
+ lateral_conv = self.lateral_convs[idx]
276
+ output_conv = self.output_convs[idx]
277
+ if lateral_conv is None:
278
+ transformer = self.input_proj(x)
279
+ pos = self.pe_layer(x)
280
+ transformer = self.transformer(transformer, None, pos)
281
+ y = output_conv(transformer)
282
+ # save intermediate feature as input to Transformer decoder
283
+ transformer_encoder_features = transformer
284
+ else:
285
+ cur_fpn = lateral_conv(x)
286
+ # Following FPN implementation, we use nearest upsampling here
287
+ y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
288
+ y = output_conv(y)
289
+ return self.mask_features(y), transformer_encoder_features
290
+
291
+ def forward(self, features, targets=None):
292
+ logger = logging.getLogger(__name__)
293
+ logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
294
+ return self.forward_features(features)