zhoupans committed
Commit 3c849be
1 Parent(s): 6c740f8

Upload 13 files

Files changed (13)
  1. Dockerfile +62 -0
  2. LICENSE +202 -0
  3. README.md +180 -8
  4. main.py +814 -0
  5. pretraining.sh +231 -0
  6. src/RandAugment.py +506 -0
  7. src/dataset.py +367 -0
  8. src/loss.py +244 -0
  9. src/model.py +607 -0
  10. src/multicropdataset.py +445 -0
  11. src/optimizer.py +210 -0
  12. src/vision_transformer.py +491 -0
  13. utils.py +583 -0
Dockerfile ADDED
@@ -0,0 +1,62 @@
1
+ FROM nvidia/cuda:11.0-base-ubuntu20.04
2
+
3
+ ENV DEBIAN_FRONTEND noninteractive
4
+
5
+ ENV CUDNN_VERSION=8.0.5.39-1+cuda11.1
6
+ ENV NCCL_VERSION=2.7.8-1+cuda11.1
7
+
8
+ ARG python=3.8
9
+ ENV PYTHON_VERSION=${python}
10
+
11
+ # Set default shell to /bin/bash
12
+ SHELL ["/bin/bash", "-cu"]
13
+
14
+ RUN apt-get update && apt-get install -y --allow-downgrades \
15
+ --allow-change-held-packages --no-install-recommends \
16
+ build-essential \
17
+ cmake \
18
+ git \
19
+ curl \
20
+ vim \
21
+ wget \
22
+ ca-certificates \
23
+ libcudnn8=${CUDNN_VERSION} \
24
+ libnccl2=${NCCL_VERSION} \
25
+ libnccl-dev=${NCCL_VERSION} \
26
+ libjpeg-dev \
27
+ libpng-dev \
28
+ python${PYTHON_VERSION} \
29
+ python${PYTHON_VERSION}-dev \
30
+ python${PYTHON_VERSION}-distutils \
31
+ librdmacm1 \
32
+ libibverbs1 \
33
+ ibverbs-providers
34
+
35
+ RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
36
+
37
+ RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
38
+ python get-pip.py && \
39
+ rm get-pip.py
40
+
41
+ RUN /usr/bin/python -m pip install --upgrade pip
42
+
43
+ # Install pytorch
44
+ RUN pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 \
45
+ -f https://download.pytorch.org/whl/torch_stable.html
46
+
47
+ RUN pip install tensorboard==2.5.0
48
+ RUN pip install tensorboard-data-server==0.6.1
49
+ RUN pip install tensorboard-plugin-wit==1.8.0
50
+ RUN pip install tensorboardX==1.8
51
+
52
+ RUN pip install timm==0.4.5
53
+ RUN pip install opencv-contrib-python-headless==4.5.2.54
54
+ RUN pip install tqdm==4.61.2
55
+ RUN pip install PyYAML==5.4.1
56
+ RUN pip install Pillow==8.3.1
57
+ RUN pip install einops==0.3.0
58
+ RUN pip install scipy==1.7.1
59
+
60
+
61
+
62
+
LICENSE ADDED
@@ -0,0 +1,202 @@
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright 2022 Garena Online Private Limited
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
README.md CHANGED
@@ -1,8 +1,180 @@
1
- ---
2
- license: apache-2.0
3
- datasets:
4
- - imagenet-1k
5
- metrics:
6
- - accuracy
7
- pipeline_tag: image-classification
8
- ---
1
+ # Mugs: A Multi-Granular Self-Supervised Learning Framework
2
+
3
+ This is a PyTorch implementation of **Mugs** proposed by our paper "**Mugs: A Multi-Granular Self-Supervised Learning Framework**". [![arXiv](https://img.shields.io/badge/arXiv-2203.14415-b31b1b.svg?style=flat)](http://arxiv.org/abs/2203.14415)
4
+
5
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/mugs-a-multi-granular-self-supervised/self-supervised-image-classification-on)](https://paperswithcode.com/sota/self-supervised-image-classification-on?p=mugs-a-multi-granular-self-supervised)
6
+
7
+ <div align="center">
8
+ <img width="100%" alt="Overall framework of Mugs. " src="./exp_illustration/framework.png">
9
+ </div>
10
+
11
+ **<p align="center">Fig 1. Overall framework of Mugs.** In (a), two random crops of each image
12
+ are fed into the backbones of the student and teacher. Three granular supervisions, namely 1) instance discrimination supervision, 2) local-group discrimination
13
+ supervision, and 3) group discrimination supervision, are adopted to learn multi-granular representations. In (b), the local-group module in
14
+ the student/teacher averages all patch tokens and finds the top-k neighbors in a memory buffer, which are aggregated with the average to obtain a local-group feature.</p>
15
+
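For a concrete picture of the local-group module in Fig 1(b), below is a minimal, illustrative sketch of the aggregation step (averaging patch tokens, retrieving top-k neighbors from a memory buffer, and averaging them with the query). The actual implementation lives in `src/model.py` and `src/vision_transformer.py` and may differ in details; the function name here is ours, not the repository's.

```python
# Illustrative sketch only (not the repository's exact module): build a
# local-group feature from patch tokens and a memory buffer of past features.
import torch
import torch.nn.functional as F

def local_group_feature(patch_tokens, memory, top_k=8):
    """patch_tokens: (B, N, D) tokens of one view; memory: (M, D) stored features."""
    query = patch_tokens.mean(dim=1)                                    # average all patch tokens -> (B, D)
    sim = F.normalize(query, dim=-1) @ F.normalize(memory, dim=-1).t()  # cosine similarity to memory -> (B, M)
    _, idx = sim.topk(top_k, dim=-1)                                    # indices of the top-k neighbors
    neighbors = memory[idx]                                             # (B, top_k, D)
    # aggregate the query with its neighbors to obtain the local-group feature
    return torch.cat([query.unsqueeze(1), neighbors], dim=1).mean(dim=1)
```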
16
+
17
+
18
+
19
+
20
+ # Pretrained models on ImageNet-1K
21
+
22
+ You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both student and teacher networks.
23
+ **<p align="center">Table 1. KNN and linear probing performance with their corresponding hyper-parameters, logs and model weights.</p>**
24
+ <table>
25
+ <tr>
26
+ <th>arch</th>
27
+ <th>params</th>
28
+ <th>pretraining epochs</th>
29
+ <th>k-nn</th>
30
+ <th>linear</th>
31
+ <th colspan="6">download</th>
32
+ </tr>
33
+ <tr>
34
+ <td>ViT-S/16</td>
35
+ <td>21M</td>
36
+ <td>100</td>
37
+ <td>72.3%</td>
38
+ <td>76.4%</td>
39
+ <td><a href="https://drive.google.com/file/d/1V2TyArzr7qY93UFglPBHRfYVyAMEfsHR/view?usp=sharing">backbone only</a></td>
40
+ <td><a href="https://drive.google.com/file/d/1AePcCeUEhK0nb9syQKufqqnhpUEr9Rji/view?usp=sharing">full ckpt</a></td>
41
+ <td><a href="https://drive.google.com/file/d/17phHQx88f4_xSqkPtIYvUoUE2U-1tovg/view?usp=sharing">args</a></td>
42
+ <td><a href="https://drive.google.com/file/d/1UBMTB-C3BnNKT5939fhSstHc9H30Vizd/view?usp=sharing">logs</a></td>
43
+ <td><a href="https://drive.google.com/file/d/1MkXctkgqEXjWWRs4Cz5CyTTx_IHDOP4G/view?usp=sharing">eval logs</a></td>
44
+ </tr>
45
+ <tr>
46
+ <td>ViT-S/16</td>
47
+ <td>21M</td>
48
+ <td>300</td>
49
+ <td>74.8%</td>
50
+ <td>78.2%</td>
51
+ <td><a href="https://drive.google.com/file/d/1ZAPQ0HiDZO5Uk7jVqF46H6VbGxunZkuf/view?usp=sharing">backbone only</a></td>
52
+ <td><a href="https://drive.google.com/file/d/1EO-_kYlAt23qgFYZF2u-KLks5js9LvrZ/view?usp=sharing">full ckpt</a></td>
53
+ <td><a href="https://drive.google.com/file/d/1b6zLZ3r_mZbk17SvhJIZF2VCoYVbJUnU/view?usp=sharing">args</a></td>
54
+ <td><a href="https://drive.google.com/file/d/1L7VzH1rztoraBCBNVWL-Y8k7Y8PFU773/view?usp=sharing">logs</a></td>
55
+ <td><a href="https://drive.google.com/file/d/1KgnX8ReXIVsu65_-p7NWPH8S0HEDPMUU/view?usp=sharing">eval logs</a></td>
56
+ </tr>
57
+ <tr>
58
+ <td>ViT-S/16</td>
59
+ <td>21M</td>
60
+ <td>800</td>
61
+ <td>75.6%</td>
62
+ <td>78.9%</td>
63
+ <td><a href="https://drive.google.com/file/d/1KMdhxxWc2JXAiFqVxX584V4RvlJgckGq/view?usp=sharing">backbone only</a></td>
64
+ <td><a href="https://drive.google.com/file/d/1FBaOt0Rjxm6yyJadttOyN6hSh8ueZ0dh/view?usp=sharing">full ckpt</a></td>
65
+ <td><a href="https://drive.google.com/file/d/19Ma-eSIgdwLoBg6wBXeFiW46zCI2EHvH/view?usp=sharing">args</a></td>
66
+ <td><a href="https://drive.google.com/file/d/1wX4AUO5NBVZUb8jN1iGBRkS17sszb4_O/view?usp=sharing">logs</a></td>
67
+ <td><a href="https://drive.google.com/file/d/12tiO4glWZNB044TYiPPCfbnUX_9AbqVc/view?usp=sharing">eval logs</a></td>
68
+ </tr>
69
+ <tr>
70
+ <td>ViT-B/16</td>
71
+ <td>85M</td>
72
+ <td>400</td>
73
+ <td>78.0%</td>
74
+ <td>80.6%</td>
75
+ <td><a href="https://drive.google.com/file/d/13NUziwToBXBmS7n7V_1Z5N6EG_7bcncW/view?usp=sharing">backbone only</a></td>
76
+ <td><a href="https://drive.google.com/file/d/1M41TVVFyVRDTK5kbgLCEImrxw0AVtebb/view?usp=sharing">full ckpt</a></td>
77
+ <td><a href="https://drive.google.com/file/d/1-5fB5ZCVQAfxTXZ6ro56AVkhb3whpaJc/view?usp=sharing">args</a></td>
78
+ <td><a href="https://drive.google.com/file/d/11RlCx6eViRnFD6gBlr_lOOxOhu-L6l6D/view?usp=sharing">logs</a></td>
79
+ <td><a href="https://drive.google.com/file/d/1gOR250QFLZfe40pLNPcOqaLPAnKLuE_C/view?usp=sharing">eval logs</a></td>
80
+ </tr>
81
+ <tr>
82
+ <td>ViT-L/16</td>
83
+ <td>307M</td>
84
+ <td>250</td>
85
+ <td>80.3%</td>
86
+ <td>82.1%</td>
87
+ <td><a href="https://drive.google.com/file/d/1K76a-YnFYcmDXUZ_UlYVYFrWOt2a6733/view?usp=sharing">backbone only</a></td>
88
+ <td><a href="https://drive.google.com/file/d/1Q5Ukvucx44YawyOhMEAY13Ppb8OOWOAB/view?usp=sharing">full ckpt</a></td>
89
+ <td><a href="https://drive.google.com/file/d/1p8XhaA2_Zbejm__UT8iNKG8r5tzS9c6c/view?usp=sharing">args</a></td>
90
+ <td><a href="https://drive.google.com/file/d/1JLVcUNfkyBI0BcMm7OpNU_3KTxIABK0Z/view?usp=sharing">logs</a></td>
91
+ <td><a href="https://drive.google.com/file/d/1rqWenRFN0czat_55GY9GNOu7gS6fww3g/view?usp=sharing">eval logs</a></td>
92
+ </tr>
93
+ </table>
94
+
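As a reference for using the "backbone only" weights above, here is a hedged loading sketch. The `vit_small` factory name and the flat `state_dict` layout are assumptions about `src/vision_transformer.py` (DINO-style) and should be checked against the actual checkpoint keys; the file name is a placeholder.

```python
# Hedged example of loading a downloaded "backbone only" checkpoint.
import torch
from src import vision_transformer as vits  # assumes a DINO-style factory module

model = vits.__dict__["vit_small"](patch_size=16)                      # assumed constructor name
state_dict = torch.load("vit_small_backbone.pth", map_location="cpu")  # placeholder file name
msg = model.load_state_dict(state_dict, strict=False)
print(msg)  # inspect missing/unexpected keys before trusting the weights
```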
95
+ <div align="center">
96
+ <img width="100%" alt="Comparison of linear probing accuracy on ImageNet-1K." src="./exp_illustration/comparison.png">
97
+ </div>
98
+
99
+ **<p align="center">Fig 2. Comparison of linear probing accuracy on ImageNet-1K.**</p>
100
+
101
+ ## Pretraining Settings
102
+
103
+ ### Environment
104
+ To reproduce our results, please install [PyTorch](https://pytorch.org/) and download the [ImageNet](https://imagenet.stanford.edu/) dataset.
105
+ This codebase has been developed with Python 3.8, PyTorch 1.7.1, CUDA 11.0, and torchvision 0.8.2. For the full
106
+ environment, please refer to our `Dockerfile`.
107
+
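A quick way to verify that a local environment matches the versions above (this is just a convenience check, not part of the codebase):

```python
# Print the installed versions and compare against: Python 3.8, PyTorch 1.7.1,
# torchvision 0.8.2, CUDA 11.0.
import sys
import torch
import torchvision

print(sys.version.split()[0])      # expected 3.8.x
print(torch.__version__)           # expected 1.7.1+cu110
print(torchvision.__version__)     # expected 0.8.2+cu110
print(torch.version.cuda)          # expected 11.0
print(torch.cuda.is_available())
```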
108
+
109
+ ### ViT pretraining :beer:
110
+ To pretrain each model, please find the exact hyper-parameter settings in the `args` column of [Table 1](https://github.com/sail-sg/mugs). For the training log and linear probing log, please refer to the
111
+ `logs` and `eval logs` columns of [Table 1](https://github.com/sail-sg/mugs).
112
+
113
+ #### ViT-Small pretraining:
114
+ To run ViT-small for 100 epochs, we use two nodes with a total of 8 A100 GPUs (total minibatch size of 512) and run the following command:
115
+ ```
116
+ python -m torch.distributed.launch --nproc_per_node=8 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_small
117
+ --group_teacher_temp 0.04 --group_warmup_teacher_temp_epochs 0 --weight_decay_end 0.2 --norm_last_layer false --epochs 100
118
+ ```
119
+ To run ViT-small for 300 epochs, we use two nodes with a total of 16 A100 GPUs (total minibatch size of 1024) and run the following command:
120
+ ```
121
+ python -m torch.distributed.launch --nproc_per_node=16 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_small
122
+ --group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 30 --weight_decay_end 0.1 --norm_last_layer false --epochs 300
123
+ ```
124
+ To run ViT-small for 800 epochs, we use two nodes with a total of 16 A100 GPUs (total minibatch size of 1024) and run the following command:
125
+ ```
126
+ python -m torch.distributed.launch --nproc_per_node=16 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_small
127
+ --group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 30 --weight_decay_end 0.1 --norm_last_layer false --epochs 800
128
+ ```
129
+
130
+ #### ViT-Base pretraining:
131
+ To run ViT-base for 400 epochs, we use two nodes with a total of 24 A100 GPUs (total minibatch size of 1024) and run the following command:
132
+ ```
133
+ python -m torch.distributed.launch --nproc_per_node=24 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_base
134
+ --group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 50 --min_lr 2e-06 --weight_decay_end 0.1 --freeze_last_layer 3 --norm_last_layer
135
+ false --epochs 400
136
+ ```
137
+
138
+ #### ViT-Large pretraining:
139
+ To run ViT-large for 250 epochs, we use two nodes with a total of 40 A100 GPUs (total minibatch size of 640) and run the following command:
140
+ ```
141
+ python -m torch.distributed.launch --nproc_per_node=40 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_large
142
+ --lr 0.0015 --min_lr 1.5e-4 --group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 50 --weight_decay 0.025
143
+ --weight_decay_end 0.08 --norm_last_layer true --drop_path_rate 0.3 --freeze_last_layer 3 --epochs 250
144
+ ```
145
+
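Note that, per the `--lr` help in `main.py`, the learning rate passed on the command line is defined for a reference batch size of 256 and is linearly scaled by the actual total batch size. A small sketch of that rule (the function name is ours):

```python
# Linear learning-rate scaling rule used by the training script (see --lr in main.py).
def scaled_lr(base_lr, batch_size_per_gpu, num_gpus, reference_batch_size=256):
    return base_lr * (batch_size_per_gpu * num_gpus) / reference_batch_size

print(scaled_lr(0.0008, 64, 8))   # ViT-S/16 with a 512 total batch size -> 0.0016
```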
146
+ ## Evaluation
147
+ We are cleaning up the evaluation code and will release it when it is ready.
148
+
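Until the official evaluation code is released, note that the k-NN numbers in Table 1 follow the standard protocol; below is a minimal, hedged sketch of DINO-style weighted k-NN voting on pre-extracted, L2-normalized features. It is not the authors' script.

```python
# Minimal k-NN classification over pre-extracted features (DINO-style weighted voting).
import torch

def knn_classify(train_feats, train_labels, test_feats, k=20, temperature=0.07, num_classes=1000):
    """Feature tensors are assumed L2-normalized with shape (N, D); labels are int64."""
    sim = test_feats @ train_feats.t()              # cosine similarity (num_test, num_train)
    topk_sim, topk_idx = sim.topk(k, dim=1)
    topk_labels = train_labels[topk_idx]            # (num_test, k)
    weights = (topk_sim / temperature).exp()
    votes = torch.zeros(test_feats.size(0), num_classes)
    votes.scatter_add_(1, topk_labels, weights)     # accumulate weighted votes per class
    return votes.argmax(dim=1)
```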
149
+ ## Self-attention visualization
150
+ Here we provide the self-attention maps of the [CLS] token from the heads of the last layer.
151
+ <div align="center">
152
+ <img width="100%" alt="Self-attention from a ViT-Base/16 trained with Mugs" src="./exp_illustration/attention_vis.png">
153
+ </div>
154
+
155
+ **<p align="center">Fig 3. Self-attention from a ViT-Base/16 trained with Mugs.**</p>
156
+
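To reproduce this kind of visualization, the [CLS] attention of the last block can be reshaped into per-head maps; the sketch below assumes a DINO-style `get_last_selfattention` method, which may not match this repository's ViT API exactly.

```python
# Hedged sketch: per-head [CLS] attention maps from the last transformer block.
import torch

def cls_attention_maps(model, image):
    """image: (3, H, W) tensor preprocessed for the model; returns (heads, n, n) maps."""
    with torch.no_grad():
        attn = model.get_last_selfattention(image.unsqueeze(0))  # assumed API: (1, heads, N, N)
    cls_attn = attn[0, :, 0, 1:]                                  # attention from [CLS] to patch tokens
    side = int(cls_attn.shape[-1] ** 0.5)
    return cls_attn.reshape(-1, side, side)
```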
157
+
158
+ ## T-SNE visualization
159
+ Here we provide the T-SNE visualization of the features learned by ViT-B/16.
160
+ We show the fish classes in ImageNet-1K, i.e., the first six classes,
161
+ including tench, goldfish, white shark, tiger shark, hammerhead, and electric
162
+ ray. See more examples in the Appendix.
163
+ <div align="center">
164
+ <img width="100%" alt="T-SNE visualization of the learned feature by ViT-B/16." src="./exp_illustration/TSNE.png">
165
+ </div>
166
+
167
+ **<p align="center">Fig 4. T-SNE visualization of the learned feature by ViT-B/16.**</p>
168
+
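A figure of this kind can be reproduced with scikit-learn once features have been extracted; the sketch below uses random placeholders for the features and labels of the six fish classes and is not the plotting code used for Fig 4.

```python
# Hedged sketch of a t-SNE plot over extracted features (placeholders used here).
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

features = np.random.randn(600, 768)    # placeholder: (num_images, feature_dim) from the backbone
labels = np.repeat(np.arange(6), 100)   # placeholder: six fish classes, 100 images each
embedding = TSNE(n_components=2, init="pca", perplexity=30).fit_transform(features)
plt.scatter(embedding[:, 0], embedding[:, 1], c=labels, s=4, cmap="tab10")
plt.savefig("tsne_fish_classes.png", dpi=200)
```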
169
+ ## License
170
+ This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file.
171
+
172
+ ## Citation
173
+ If you find this repository useful, please consider giving a star :star: and citation :beer::
174
+ ```
175
+ @inproceedings{mugs2022SSL,
176
+ title={Mugs: A Multi-Granular Self-Supervised Learning Framework},
177
+ author={Pan Zhou and Yichen Zhou and Chenyang Si and Weihao Yu and Teck Khim Ng and Shuicheng Yan},
178
+ booktitle={arXiv preprint arXiv:2203.14415},
179
+ year={2022}
180
+ }
+ ```
main.py ADDED
@@ -0,0 +1,814 @@
1
+ # Copyright 2022 Garena Online Private Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Mugs training code
16
+ """
17
+ import argparse
18
+ import datetime
19
+ import json
20
+ import math
21
+ import os
22
+ import sys
23
+ import time
24
+ from collections import OrderedDict
25
+ from pathlib import Path
26
+
27
+ import torch
28
+ import torch.backends.cudnn as cudnn
29
+ import torch.nn as nn
30
+ from torchvision import models as torchvision_models
31
+
32
+ import utils
33
+ from src.loss import get_multi_granular_loss
34
+ from src.model import get_model
35
+ from src.multicropdataset import data_prefetcher, get_dataset
36
+ from src.optimizer import cancel_gradients_last_layer, get_optimizer, clip_gradients
37
+
38
+ torchvision_archs = sorted(
39
+ name
40
+ for name in torchvision_models.__dict__
41
+ if name.islower()
42
+ and not name.startswith("__")
43
+ and callable(torchvision_models.__dict__[name])
44
+ )
45
+
46
+
47
+ def get_args_parser():
48
+ parser = argparse.ArgumentParser("Mugs", add_help=False)
49
+
50
+ ##======== Model parameters ============
51
+ parser.add_argument(
52
+ "--arch",
53
+ type=str,
54
+ default="vit_small",
55
+ choices=["vit_small", "vit_base", "vit_large"],
56
+ help="""Name of architecture to train.""",
57
+ )
58
+ parser.add_argument(
59
+ "--patch_size",
60
+ type=int,
61
+ default=16,
62
+ help="""Size in pixels
63
+ of input square patches - default 16 (for 16x16 patches). Using smaller
64
+ values leads to better performance but requires more memory. Applies only
65
+ for ViTs (vit_small and vit_base). If <16, we recommend disabling
66
+ mixed precision training (--use_fp16 false) to avoid instabilities.""",
67
+ )
68
+
69
+ ##======== Training/Optimization parameters ============
70
+ parser.add_argument(
71
+ "--momentum_teacher",
72
+ type=float,
73
+ default=0.996,
74
+ help="""Base EMA
75
+ parameter for teacher update. The value is increased to 1 during training with
76
+ cosine schedule. We recommend setting a higher value with small batches: for
77
+ example use 0.9995 with batch size of 256.""",
78
+ )
79
+ parser.add_argument(
80
+ "--use_fp16",
81
+ type=utils.bool_flag,
82
+ default=False,
83
+ help="""Whether or not
84
+ to use half precision for training. Improves training time and memory requirements,
85
+ but can provoke instability and slight decay of performance. We recommend disabling
86
+ mixed precision if the loss is unstable, if reducing the patch size or if training
87
+ with bigger ViTs.""",
88
+ )
89
+ parser.add_argument(
90
+ "--weight_decay",
91
+ type=float,
92
+ default=0.04,
93
+ help="""Initial value of the
94
+ weight decay. With ViT, a smaller value at the beginning of training works well.""",
95
+ )
96
+ parser.add_argument(
97
+ "--weight_decay_end",
98
+ type=float,
99
+ default=0.2,
100
+ help="""Final value of the
101
+ weight decay. We use a cosine schedule for WD and using a larger decay by
102
+ the end of training improves performance for ViTs.""",
103
+ )
104
+ parser.add_argument(
105
+ "--clip_grad",
106
+ type=float,
107
+ default=3.0,
108
+ help="""Maximal parameter
109
+ gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
110
+ help optimization for larger ViT architectures. 0 for disabling.""",
111
+ )
112
+ parser.add_argument(
113
+ "--batch_size_per_gpu",
114
+ type=int,
115
+ default=64,
116
+ help="Per-GPU batch-size : number of distinct images loaded on one GPU.",
117
+ )
118
+ parser.add_argument(
119
+ "--epochs", type=int, default=100, help="Number of epochs of training."
120
+ )
121
+ parser.add_argument(
122
+ "--warmup_epochs",
123
+ default=10,
124
+ type=int,
125
+ help="""Number of epochs for the linear learning-rate warm up.""",
126
+ )
127
+ parser.add_argument(
128
+ "--freeze_last_layer",
129
+ type=int,
130
+ default=1,
131
+ help="""Number of epochs during
132
+ which we keep the output layer fixed for the group supervision loss. Typically doing so during
133
+ the first epoch helps training. Try increasing this value if the loss does not decrease.""",
134
+ )
135
+ parser.add_argument(
136
+ "--lr",
137
+ type=float,
138
+ default=0.0008,
139
+ help="""Learning rate at the end of
140
+ linear warmup (highest LR used during training). The learning rate is linearly scaled
141
+ with the batch size, and specified here for a reference batch size of 256.""",
142
+ )
143
+ parser.add_argument(
144
+ "--patch_embed_lr_mult",
145
+ type=float,
146
+ default=0.2,
147
+ help="""For patch
148
+ embedding layer, its learning rate is lr * patch_embed_lr_mult (<1.0) in most cases, which
149
+ stabilizes training and also slightly improves the performance.""",
150
+ )
151
+ parser.add_argument(
152
+ "--min_lr",
153
+ type=float,
154
+ default=1e-6,
155
+ help="""Target LR at the
156
+ end of optimization. We use a cosine LR schedule with linear warmup.""",
157
+ )
158
+ parser.add_argument(
159
+ "--optimizer",
160
+ type=str,
161
+ default="adamw",
162
+ choices=["adamw", "sgd", "lars"],
163
+ help="""Type of optimizer. We recommend using adamw
164
+ with ViTs.""",
165
+ )
166
+ parser.add_argument(
167
+ "--drop_path_rate", type=float, default=0.1, help="""stochastic depth rate"""
168
+ )
169
+
170
+ ##======== Multi-granular supervisions (instance/local-group/group supervisions) ==========
171
+ parser.add_argument(
172
+ "--loss_weights",
173
+ type=float,
174
+ nargs="+",
175
+ default=[1.0, 1.0, 1.0],
176
+ help="""three loss weights for instance, local-group, group supervision losses in turn""",
177
+ )
178
+
179
+ parser.add_argument(
180
+ "--use_bn_in_head",
181
+ type=utils.bool_flag,
182
+ default=False,
183
+ help="Whether to use batch normalizations in the three projection heads (Default: False)",
184
+ )
185
+ parser.add_argument(
186
+ "--norm_before_pred",
187
+ type=utils.bool_flag,
188
+ default=True,
189
+ help="""Whether to use batch normalizations after projection heads (namely before
190
+ prediction heads) in instance and local-group supervisions. (Default: True)""",
191
+ )
192
+
193
+ # parameters for instance discrimination supervision
194
+ parser.add_argument(
195
+ "--instance_out_dim",
196
+ type=int,
197
+ default=256,
198
+ help="""output dimension in the projection and prediction heads.""",
199
+ )
200
+ parser.add_argument(
201
+ "--instance_queue_size",
202
+ type=int,
203
+ default=65536,
204
+ help="""the queue size of the memory to store the negative keys.""",
205
+ )
206
+ parser.add_argument(
207
+ "--instance_temp",
208
+ type=float,
209
+ default=0.2,
210
+ help="""the temperature parameters for the infoNCE loss in instance supervision.""",
211
+ )
212
+
213
+ # parameters for local-group discrimination supervision
214
+ parser.add_argument(
215
+ "--local_group_out_dim",
216
+ type=int,
217
+ default=256,
218
+ help="""output dimension in the projection and prediction heads.""",
219
+ )
220
+ parser.add_argument(
221
+ "--local_group_knn_top_n",
222
+ type=int,
223
+ default=8,
224
+ help="how many neighbors we use to aggregate for a local-group",
225
+ )
226
+ parser.add_argument(
227
+ "--local_group_queue_size",
228
+ type=int,
229
+ default=65536,
230
+ help="""the queue sizes of the memory to store the negative keys for infoNCE loss and
231
+ another memory to store the weakly augmented samples for local-group aggregation.""",
232
+ )
233
+ parser.add_argument(
234
+ "--local_group_temp",
235
+ type=float,
236
+ default=0.2,
237
+ help="""the temperature parameter for the infoNCE loss in local-group supervision.""",
238
+ )
239
+
240
+ ## parameters for group discrimination supervision
241
+ parser.add_argument(
242
+ "--group_out_dim",
243
+ type=int,
244
+ default=65536,
245
+ help="""output dimension in the prediction heads.""",
246
+ )
247
+ parser.add_argument(
248
+ "--group_bottleneck_dim",
249
+ type=float,
250
+ default=256,
251
+ help="""head bottleneck dimension in the prediction heads.""",
252
+ )
253
+ parser.add_argument(
254
+ "--norm_last_layer",
255
+ type=utils.bool_flag,
256
+ default=True,
257
+ help="""Whether or not to weight normalize the last layer of the group supervision head.
258
+ Not normalizing leads to better performance but can make the training unstable. We
259
+ typically set this parameter to False with vit_small and True with vit_base and vit_large.""",
260
+ )
261
+
262
+ parser.add_argument(
263
+ "--group_student_temp",
264
+ type=float,
265
+ default=0.1,
266
+ help="""the temperature parameters for the clustering loss in student output.""",
267
+ )
268
+ parser.add_argument(
269
+ "--group_warmup_teacher_temp",
270
+ default=0.04,
271
+ type=float,
272
+ help="""Initial value for the teacher temperature: 0.04 works well in most cases.
273
+ Try decreasing it if the training loss does not decrease.""",
274
+ )
275
+ parser.add_argument(
276
+ "--group_teacher_temp",
277
+ default=0.04,
278
+ type=float,
279
+ help="""Final value
280
+ (after linear warmup) of the teacher temperature. For most experiments, anything above
281
+ 0.07 is unstable. We recommend starting with the default value of 0.04 and increase
282
+ this slightly if needed.""",
283
+ )
284
+ parser.add_argument(
285
+ "--group_warmup_teacher_temp_epochs",
286
+ default=0,
287
+ type=int,
288
+ help="""Number of warmup epochs for the teacher temperature (Default: 0).""",
289
+ )
290
+
291
+ ##======== augmentation parameters ============
292
+ # Multi-crop parameters
293
+ parser.add_argument(
294
+ "--global_crops_scale",
295
+ type=float,
296
+ nargs="+",
297
+ default=(0.25, 1.0),
298
+ help="""Scale range of the cropped image before resizing, relative to the original image.
299
+ Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
300
+ recommend using a wider range of scale ("--global_crops_scale 0.14 1." for example)""",
301
+ )
302
+ parser.add_argument(
303
+ "--local_crops_number",
304
+ type=int,
305
+ default=10,
306
+ help="""Number of small
307
+ local views to generate. Set this parameter to 0 to disable multi-crop training.
308
+ When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """,
309
+ )
310
+ parser.add_argument(
311
+ "--local_crops_scale",
312
+ type=float,
313
+ nargs="+",
314
+ default=(0.05, 0.25),
315
+ help="""Scale range of the cropped image before resizing, relative to the original image.
316
+ Used for small local view cropping of multi-crop.""",
317
+ )
318
+ # strong augmentation parameters
319
+ parser.add_argument(
320
+ "--timm_auto_augment_par",
321
+ type=str,
322
+ default="rand-m9-mstd0.5-inc1",
323
+ help="""the parameters for the AutoAugment used in DeiT.""",
324
+ )
325
+ parser.add_argument(
326
+ "--color_aug",
327
+ type=utils.bool_flag,
328
+ default=False,
329
+ help="""after AutoAugment, whether we further perform color augmentation. (Default: False).""",
330
+ )
331
+ parser.add_argument(
332
+ "--size_crops",
333
+ type=int,
334
+ default=[96],
335
+ nargs="+",
336
+ help="""the small crop size. Note we use multi-crop strategy, namely two 224-sized crops +
337
+ ten 96-sized crops. (Default: 96)""",
338
+ )
339
+ parser.add_argument(
340
+ "--strong_ratio",
341
+ type=float,
342
+ default=0.45,
343
+ help="""the ratio of image augmentation for the AutoAugment used in DeiT.""",
344
+ )
345
+ parser.add_argument(
346
+ "--re_prob",
347
+ type=float,
348
+ default=0.25,
349
+ help="""the re-prob parameter of image augmentation for the AutoAugment used in DeiT.""",
350
+ )
351
+ parser.add_argument(
352
+ "--vanilla_weak_augmentation",
353
+ type=utils.bool_flag,
354
+ default=False,
355
+ help="""Whether we use the same augmentation in DINO, namely only using weak augmentation.""",
356
+ )
357
+ parser.add_argument(
358
+ "--prob",
359
+ type=float,
360
+ default=0.5,
361
+ help="""When we use strong augmentation and weak augmentation, the ratio of images to
362
+ be cropped with strong augmentation.""",
363
+ )
364
+
365
+ ##======== Misc ============
366
+ parser.add_argument(
367
+ "--data_path",
368
+ default="/dataset/imageNet100_sicy/train/",
369
+ type=str,
370
+ help="""Please specify path to the ImageNet training data.""",
371
+ )
372
+ parser.add_argument(
373
+ "--output_dir",
374
+ default="./exp/",
375
+ type=str,
376
+ help="""Path to save logs and checkpoints.""",
377
+ )
378
+ parser.add_argument(
379
+ "--saveckp_freq",
380
+ default=50,
381
+ type=int,
382
+ help="""Save checkpoint every x epochs.""",
383
+ )
384
+ parser.add_argument("--seed", default=0, type=int, help="""Random seed.""")
385
+ parser.add_argument(
386
+ "--num_workers",
387
+ default=12,
388
+ type=int,
389
+ help="""Number of data loading workers per GPU.""",
390
+ )
391
+ parser.add_argument(
392
+ "--dist_url",
393
+ default="env://",
394
+ type=str,
395
+ help="""url used to set up
396
+ distributed training; see https://pytorch.org/docs/stable/distributed.html""",
397
+ )
398
+ parser.add_argument(
399
+ "--local_rank",
400
+ default=0,
401
+ type=int,
402
+ help="""local rank for distributed training.""",
403
+ )
404
+ parser.add_argument(
405
+ "--rank", default=0, type=int, help="""rank for distributed training."""
406
+ )
407
+ parser.add_argument(
408
+ "--world_size",
409
+ default=1,
410
+ type=int,
411
+ help="""world size for distributed training.""",
412
+ )
413
+
414
+ parser.add_argument(
415
+ "--use_prefetcher",
416
+ type=utils.bool_flag,
417
+ default=True,
418
+ help="""whether we use a prefetcher, which can accelerate the training speed.""",
419
+ )
420
+ parser.add_argument(
421
+ "--debug",
422
+ type=utils.bool_flag,
423
+ default=False,
424
+ help="""whether we debug. if yes, we only load small fraction of training data to reduce data reading time.""",
425
+ )
426
+ parser.add_argument(
427
+ "--ddpjob",
428
+ default=False,
429
+ type=utils.bool_flag,
430
+ help="""whether we use a DDP job. We suggest using it for distributed training. For a single GPU
431
+ or node, you can disable it.""",
432
+ )
433
+
434
+ return parser
435
+
436
+
437
+ def train_mugs(args):
438
+ """
439
+ main training code for Mugs, including building dataloader, models, losses, optimizers, etc
440
+ """
441
+ ##======== prepare logger for more detailed logs ============
442
+ logger = utils.get_logger(args.output_dir + "/train.log")
443
+ logger.info(args)
444
+ if args.output_dir and utils.is_main_process():
445
+ with (Path(args.output_dir) / "log.txt").open("a") as f:
446
+ f.write(str(args) + "\n")
447
+
448
+ ##======== initialize distributed training ============
449
+ if args.ddpjob is True:
450
+ utils.init_distributed_ddpjob(args)
451
+ else:
452
+ utils.init_distributed_mode(args)
453
+
454
+ ##======== fix seed for reproducibility ============
455
+ utils.fix_random_seeds(args.seed)
456
+ print("git:\n {}\n".format(utils.get_sha()))
457
+ print(
458
+ "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
459
+ )
460
+ cudnn.benchmark = True
461
+ cudnn.deterministic = True
462
+
463
+ ##======== get the training dataset/loader ============
464
+ data_loader = get_dataset(args)
465
+ logger.info(f"Data loaded: there are {len(data_loader.dataset)} images.")
466
+
467
+ ##====== build student and teacher networks (vit_small, vit_base, vit_large) =========
468
+ student, teacher, student_mem, teacher_mem = get_model(args)
469
+
470
+ # move networks to gpu
471
+ student, teacher = student.cuda(), teacher.cuda()
472
+ student_mem, teacher_mem = student_mem.cuda(), teacher_mem.cuda()
473
+
474
+ # synchronize batch norms (if any)
475
+ if utils.has_batchnorms(student):
476
+ student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
477
+ teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)
478
+ # we need DDP wrapper to have synchro batch norms working...
479
+ teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu])
480
+ teacher_without_ddp = teacher.module
481
+ else:
482
+ # teacher_without_ddp and teacher are the same thing
483
+ teacher_without_ddp = teacher
484
+ student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
485
+ # teacher and student start with the same weights
486
+ teacher_without_ddp.load_state_dict(student.module.state_dict(), strict=False)
487
+
488
+ # there is no backpropagation through the teacher, so no need for gradients
489
+ for p in teacher.parameters():
490
+ p.requires_grad = False
491
+ print(f"Student and Teacher are built: they are both {args.arch} network.")
492
+
493
+ ##======== get multi granular losses and their loss weights ============
494
+ all_losses, all_weights = get_multi_granular_loss(args)
495
+
496
+ ##======== preparing optimizer ============
497
+ optimizer, fp16_scaler, lr_schedule, wd_schedule, momentum_schedule = get_optimizer(
498
+ student, len(data_loader), args
499
+ )
500
+
501
+ ##======== optionally resume training ============
502
+ to_restore = {"epoch": 0}
503
+ utils.restart_from_checkpoint(
504
+ os.path.join(args.output_dir, "checkpoint.pth"),
505
+ run_variables=to_restore,
506
+ student=student,
507
+ teacher=teacher,
508
+ optimizer=optimizer,
509
+ fp16_scaler=fp16_scaler,
510
+ student_mem=student_mem,
511
+ teacher_mem=teacher_mem,
512
+ **all_losses,
513
+ )
514
+ start_epoch = to_restore["epoch"]
515
+
516
+ ##======== Starting Mugs training ============
517
+ logger.info("Starting Mugs training !")
518
+ start_time = time.time()
519
+ for epoch in range(start_epoch, args.epochs):
520
+ t1 = time.time()
521
+ data_loader.sampler.set_epoch(epoch)
522
+
523
+ ##======== training one epoch of Mugs ============
524
+ train_stats = train_one_epoch(
525
+ student,
526
+ teacher,
527
+ teacher_without_ddp,
528
+ all_losses,
529
+ all_weights,
530
+ data_loader,
531
+ optimizer,
532
+ lr_schedule,
533
+ wd_schedule,
534
+ momentum_schedule,
535
+ epoch,
536
+ fp16_scaler,
537
+ student_mem,
538
+ teacher_mem,
539
+ logger,
540
+ args,
541
+ )
542
+
543
+ ##======== save model checkpoint ============
544
+ save_dict = {
545
+ "student": student.state_dict(),
546
+ "teacher": teacher.state_dict(),
547
+ "student_mem": student_mem.state_dict()
548
+ if student_mem is not None
549
+ else None,
550
+ "teacher_mem": teacher_mem.state_dict()
551
+ if teacher_mem is not None
552
+ else None,
553
+ "optimizer": optimizer.state_dict(),
554
+ "epoch": epoch + 1,
555
+ "args": args,
556
+ }
557
+ granular_loss_dicts = {}
558
+ for name, loss in all_losses.items():
559
+ granular_loss_dicts[name] = loss.state_dict()
560
+ save_dict.update(granular_loss_dicts)
561
+
562
+ if fp16_scaler is not None:
563
+ save_dict["fp16_scaler"] = fp16_scaler.state_dict()
564
+
565
+ utils.save_on_master(save_dict, os.path.join(args.output_dir, "checkpoint.pth"))
566
+ if args.saveckp_freq and epoch % args.saveckp_freq == 0:
567
+ utils.save_on_master(
568
+ save_dict, os.path.join(args.output_dir, f"checkpoint{epoch:04}.pth")
569
+ )
570
+
571
+ ##======== writing logs ============
572
+ log_stats = {**{f"{k}": v for k, v in train_stats.items()}, "epoch": epoch}
573
+ if utils.is_main_process():
574
+ with (Path(args.output_dir) / "log.txt").open("a") as f:
575
+ f.write(json.dumps(log_stats) + "\n")
576
+
577
+ t2 = time.time()
578
+ log_results = ""
579
+ for k, v in train_stats.items():
580
+ log_results += "%s: %.6f, " % (k, v)
581
+ logger.info(
582
+ "%d-epoch: %s remaining time %.2f hours"
583
+ % (epoch, log_results, (t2 - t1) * (args.epochs - epoch) / 3600.0)
584
+ )
585
+
586
+ total_time = time.time() - start_time
587
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
588
+ logger.info("Training time {}".format(total_time_str))
589
+
590
+
591
+ def train_one_epoch(
592
+ student,
593
+ teacher,
594
+ teacher_without_ddp,
595
+ all_losses,
596
+ all_weights,
597
+ data_loader,
598
+ optimizer,
599
+ lr_schedule,
600
+ wd_schedule,
601
+ momentum_schedule,
602
+ epoch,
603
+ fp16_scaler,
604
+ student_mem,
605
+ teacher_mem,
606
+ logger,
607
+ args,
608
+ ):
609
+ """
610
+ main training code for each epoch
611
+ """
612
+ metric_logger = utils.MetricLogger(delimiter=" ")
613
+ prefetcher = data_prefetcher(data_loader, fp16=(fp16_scaler is not None))
614
+ images, weak_aug_flags = prefetcher.next()
615
+ epoch_it = 0
616
+ while images is not None:
617
+ # Step 1. update weight decay and learning rate according to their schedule
618
+ it = len(data_loader) * epoch + epoch_it # global training iteration
619
+ for _, param_group in enumerate(optimizer.param_groups):
620
+ lr_mult = 1.0
621
+ if "patch_embed" in param_group["name"]:
622
+ lr_mult = args.patch_embed_lr_mult
623
+ param_group["lr"] = lr_schedule[it] * lr_mult
624
+ if param_group.get("apply_wd", True): # only the first group is regularized
625
+ param_group["weight_decay"] = wd_schedule[it]
626
+
627
+ granular_losses = OrderedDict()
628
+ total_loss = 0
629
+ with torch.cuda.amp.autocast(fp16_scaler is not None):
630
+ ## Step 2. forward images into teacher and student to obtain the
631
+ # features/supervisions for the three granular supervision losses
632
+ (
633
+ teacher_instance_target,
634
+ teacher_local_group_target,
635
+ teacher_group_target,
636
+ teacher_memory_tokens,
637
+ ) = teacher(
638
+ images[:2],
639
+ return_target=True,
640
+ local_group_memory_inputs={"mem": teacher_mem},
641
+ )
642
+
643
+ (
644
+ student_instance_target,
645
+ student_local_group_target,
646
+ student_group_target,
647
+ student_memory_tokens,
648
+ ) = student(
649
+ images[2:],
650
+ return_target=False,
651
+ local_group_memory_inputs={"mem": student_mem},
652
+ )
653
+
654
+ ## Step 3. compute the three granular supervision losses, including instance,
655
+ # local-group, group supervision losses
656
+ weigts_sum, total_loss, granular_losses = 0.0, 0.0, OrderedDict()
657
+ # instance loss
658
+ loss_cls, loss_weight = (
659
+ all_losses["instance-sup."],
660
+ all_weights["instance-sup."],
661
+ )
662
+ if loss_weight > 0:
663
+ instance_loss = loss_cls(
664
+ student_instance_target, teacher_instance_target, epoch
665
+ )
666
+ weigts_sum, total_loss = (
667
+ weigts_sum + loss_weight,
668
+ total_loss + instance_loss,
669
+ )
670
+ granular_losses["instance-sup."] = instance_loss.item()
671
+
672
+ # local group loss
673
+ loss_cls, loss_weight = (
674
+ all_losses["local-group-sup."],
675
+ all_weights["local-group-sup."],
676
+ )
677
+ if loss_weight > 0:
678
+ local_group_loss = loss_cls(
679
+ student_local_group_target, teacher_local_group_target, epoch
680
+ )
681
+ weigts_sum, total_loss = (
682
+ weigts_sum + loss_weight,
683
+ total_loss + local_group_loss,
684
+ )
685
+ granular_losses["local-group-sup."] = local_group_loss.item()
686
+
687
+ # group loss
688
+ loss_cls, loss_weight = all_losses["group-sup."], all_weights["group-sup."]
689
+ if loss_weight > 0:
690
+ group_loss = loss_cls(student_group_target, teacher_group_target, epoch)
691
+ weigts_sum, total_loss = (
692
+ weigts_sum + loss_weight,
693
+ total_loss + group_loss,
694
+ )
695
+ granular_losses["group-sup."] = group_loss.item()
696
+
697
+ # average loss
698
+ total_loss /= weigts_sum
699
+
700
+ ## Step 4. update the memory buffer for local-group supervision losses.
701
+ # for student, we only update memory by the image of size 224 and weak augmentations
702
+ student_features = (student_memory_tokens.chunk(2))[0]
703
+ len_weak = student_mem._dequeue_and_enqueue(
704
+ student_features,
705
+ weak_aug_flags,
706
+ )
707
+
708
+ teacher_weak = (teacher_memory_tokens.chunk(2))[0]
709
+ _ = teacher_mem._dequeue_and_enqueue(teacher_weak, None)
710
+
711
+ if not math.isfinite(total_loss.item()):
712
+ print("Loss is {}, stopping training".format(total_loss.item()), force=True)
713
+ sys.exit(1)
714
+
715
+ ## Step 5. student and teacher update
716
+ # student update
717
+ optimizer.zero_grad()
718
+ if fp16_scaler is None:
719
+ total_loss.backward()
720
+ if args.clip_grad:
721
+ clip_grad = args.clip_grad
722
+ if epoch > 100 and args.arch == "vit_large":
723
+ clip_grad = args.clip_grad / 10.0
724
+ _ = clip_gradients(student, clip_grad)
725
+ cancel_gradients_last_layer(epoch, student, args.freeze_last_layer)
726
+ optimizer.step()
727
+ else:
728
+ fp16_scaler.scale(total_loss).backward()
729
+ if args.clip_grad:
730
+ clip_grad = args.clip_grad
731
+ if epoch > 100 and args.arch == "vit_large":
732
+ clip_grad = args.clip_grad /10.0
733
+ fp16_scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
734
+ _ = clip_gradients(student, clip_grad)
735
+ cancel_gradients_last_layer(epoch, student, args.freeze_last_layer)
736
+ fp16_scaler.step(optimizer)
737
+ fp16_scaler.update()
738
+
739
+ # EMA update for the teacher
740
+ with torch.no_grad():
741
+ m = momentum_schedule[it] # momentum parameter
742
+ for param_q, param_k in zip(
743
+ student.module.backbone.parameters(),
744
+ teacher_without_ddp.backbone.parameters(),
745
+ ):
746
+ param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
747
+
748
+ if teacher_without_ddp.instance_head is not None:
749
+ for param_q, param_k in zip(
750
+ student.module.instance_head.parameters(),
751
+ teacher_without_ddp.instance_head.parameters(),
752
+ ):
753
+ param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
754
+
755
+ if teacher_without_ddp.local_group_head is not None:
756
+ for param_q, param_k in zip(
757
+ student.module.local_group_head.parameters(),
758
+ teacher_without_ddp.local_group_head.parameters(),
759
+ ):
760
+ param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
761
+
762
+ if teacher_without_ddp.group_head is not None:
763
+ for param_q, param_k in zip(
764
+ student.module.group_head.parameters(),
765
+ teacher_without_ddp.group_head.parameters(),
766
+ ):
767
+ param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
768
+
769
+ ## Step 6. load images
770
+ images, weak_aug_flags = prefetcher.next()
771
+ epoch_it += 1
772
+
773
+ ## Step 7. logging
774
+ torch.cuda.synchronize()
775
+ metric_logger.update(loss=total_loss.item())
776
+ for loss_name, loss_value in granular_losses.items():
777
+ metric_logger.update(**{loss_name: loss_value})
778
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
779
+ metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
780
+
781
+ if epoch_it % 500 == 0 and args.rank == 0: # and epoch_it < 10:
782
+ log_results = ""
783
+ for _, loss_name in enumerate(all_losses):
784
+ if all_weights[loss_name] > 0:
785
+ log_results += "%s: %.6f," % (
786
+ loss_name,
787
+ metric_logger.meters[loss_name].global_avg,
788
+ )
789
+ logger.info(
790
+ "%d-epoch (%d/%d): total loss %.6f, %s, lr %.4e, wd %.4e, weak aug. ratio %.1f"
791
+ % (
792
+ epoch,
793
+ it,
794
+ len(data_loader),
795
+ metric_logger.meters["loss"].global_avg,
796
+ log_results,
797
+ optimizer.param_groups[0]["lr"],
798
+ optimizer.param_groups[0]["weight_decay"],
799
+ len_weak / len(weak_aug_flags) / args.world_size,
800
+ )
801
+ )
802
+
803
+ # gather the stats from all processes
804
+ metric_logger.synchronize_between_processes()
805
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
806
+
807
+
808
+ if __name__ == "__main__":
809
+ parser = argparse.ArgumentParser("Mugs", parents=[get_args_parser()])
810
+ args = parser.parse_args()
811
+ if not os.path.exists(args.output_dir):
812
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
813
+
814
+ train_mugs(args)
pretraining.sh ADDED
@@ -0,0 +1,231 @@
1
+ #!/bin/bash
2
+ DATASET_ROOT=/dataset/imageNet100_sicy/train/ #/raid/common/imagenet-raw/
3
+
4
+ ## train ViT-small for 100 epochs
5
+ OUTPUT_ROOT=./exps/vit_small_100ep
6
+ NPROC_PER_NODE=8 # GPU numbers
7
+ BATCH_SIZE_PER_GPU=64
8
+ DEBUG=false # debug = true, then we only load subset of the whole training dataset
9
+ python -m torch.distributed.launch --nproc_per_node=$NPROC_PER_NODE main.py \
10
+ --data_path $DATASET_ROOT \
11
+ --output_dir $OUTPUT_ROOT \
12
+ --arch vit_small \
13
+ --instance_queue_size 65536 \
14
+ --local_group_queue_size 65536 \
15
+ --use_bn_in_head false \
16
+ --instance_out_dim 256 \
17
+ --instance_temp 0.2 \
18
+ --local_group_out_dim 256 \
19
+ --local_group_temp 0.2 \
20
+ --local_group_knn_top_n 8 \
21
+ --group_out_dim 65536 \
22
+ --group_student_temp 0.1 \
23
+ --group_warmup_teacher_temp 0.04 \
24
+ --group_teacher_temp 0.04 \
25
+ --group_warmup_teacher_temp_epochs 0 \
26
+ --norm_last_layer false \
27
+ --norm_before_pred true \
28
+ --batch_size_per_gpu $BATCH_SIZE_PER_GPU \
29
+ --epochs 100 \
30
+ --warmup_epochs 10 \
31
+ --clip_grad 3.0 \
32
+ --lr 0.0008 \
33
+ --min_lr 1e-06 \
34
+ --patch_embed_lr_mult 0.2 \
35
+ --drop_path_rate 0.1 \
36
+ --weight_decay 0.04 \
37
+ --weight_decay_end 0.2 \
38
+ --freeze_last_layer 1 \
39
+ --momentum_teacher 0.996 \
40
+ --use_fp16 false \
41
+ --local_crops_number 10 \
42
+ --size_crops 96 \
43
+ --global_crops_scale 0.25 1 \
44
+ --local_crops_scale 0.05 0.25 \
45
+ --timm_auto_augment_par rand-m9-mstd0.5-inc1 \
46
+ --prob 0.5 \
47
+ --use_prefetcher true \
48
+ --debug $DEBUG
49
+
50
+ ## train ViT-small for 300 epochs
51
+ OUTPUT_ROOT=./exps/vit_small_300ep
52
+ NPROC_PER_NODE=16 # GPU numbers
53
+ BATCH_SIZE_PER_GPU=64
54
+ DEBUG=false # if DEBUG=true, only a subset of the whole training dataset is loaded
55
+ python -m torch.distributed.launch --nproc_per_node=$NPROC_PER_NODE main.py \
56
+ --data_path $DATASET_ROOT \
57
+ --output_dir $OUTPUT_ROOT \
58
+ --arch vit_small \
59
+ --instance_queue_size 65536 \
60
+ --local_group_queue_size 65536 \
61
+ --use_bn_in_head false \
62
+ --instance_out_dim 256 \
63
+ --instance_temp 0.2 \
64
+ --local_group_out_dim 256 \
65
+ --local_group_temp 0.2 \
66
+ --local_group_knn_top_n 8 \
67
+ --group_out_dim 65536 \
68
+ --group_student_temp 0.1 \
69
+ --group_warmup_teacher_temp 0.04 \
70
+ --group_teacher_temp 0.07 \
71
+ --group_warmup_teacher_temp_epochs 30 \
72
+ --norm_last_layer false \
73
+ --norm_before_pred true \
74
+ --batch_size_per_gpu $BATCH_SIZE_PER_GPU \
75
+ --epochs 300 \
76
+ --warmup_epochs 10 \
77
+ --clip_grad 3.0 \
78
+ --lr 0.0008 \
79
+ --min_lr 1e-06 \
80
+ --patch_embed_lr_mult 0.2 \
81
+ --drop_path_rate 0.1 \
82
+ --weight_decay 0.04 \
83
+ --weight_decay_end 0.1 \
84
+ --freeze_last_layer 1 \
85
+ --momentum_teacher 0.996 \
86
+ --use_fp16 false \
87
+ --local_crops_number 10 \
88
+ --size_crops 96 \
89
+ --global_crops_scale 0.25 1 \
90
+ --local_crops_scale 0.05 0.25 \
91
+ --timm_auto_augment_par rand-m9-mstd0.5-inc1 \
92
+ --prob 0.5 \
93
+ --use_prefetcher true \
94
+ --debug $DEBUG
95
+
96
+ ## train ViT-small for 800 epochs
97
+ OUTPUT_ROOT=./exps/vit_small_800ep
+ NPROC_PER_NODE=16 # GPU numbers
98
+ BATCH_SIZE_PER_GPU=64
99
+ DEBUG=false # if DEBUG=true, only a subset of the whole training dataset is loaded
100
+ python -m torch.distributed.launch --nproc_per_node=$NPROC_PER_NODE main.py \
101
+ --data_path $DATASET_ROOT \
102
+ --output_dir $OUTPUT_ROOT \
103
+ --arch vit_small \
104
+ --instance_queue_size 65536 \
105
+ --local_group_queue_size 65536 \
106
+ --use_bn_in_head false \
107
+ --instance_out_dim 256 \
108
+ --instance_temp 0.2 \
109
+ --local_group_out_dim 256 \
110
+ --local_group_temp 0.2 \
111
+ --local_group_knn_top_n 8 \
112
+ --group_out_dim 65536 \
113
+ --group_student_temp 0.1 \
114
+ --group_warmup_teacher_temp 0.04 \
115
+ --group_teacher_temp 0.07 \
116
+ --group_warmup_teacher_temp_epochs 30 \
117
+ --norm_last_layer false \
118
+ --norm_before_pred true \
119
+ --batch_size_per_gpu $BATCH_SIZE_PER_GPU \
120
+ --epochs 800 \
121
+ --warmup_epochs 10 \
122
+ --clip_grad 3.0 \
123
+ --lr 0.0008 \
124
+ --min_lr 1e-06 \
125
+ --patch_embed_lr_mult 0.2 \
126
+ --drop_path_rate 0.1 \
127
+ --weight_decay 0.04 \
128
+ --weight_decay_end 0.1 \
129
+ --freeze_last_layer 1 \
130
+ --momentum_teacher 0.996 \
131
+ --use_fp16 false \
132
+ --local_crops_number 10 \
133
+ --size_crops 96 \
134
+ --global_crops_scale 0.25 1 \
135
+ --local_crops_scale 0.05 0.25 \
136
+ --timm_auto_augment_par rand-m9-mstd0.5-inc1 \
137
+ --prob 0.5 \
138
+ --use_prefetcher true \
139
+ --debug $DEBUG
140
+
141
+ ## train ViT-base for 400 epochs
142
+ OUTPUT_ROOT=./exps/vit_base_400ep
143
+ NPROC_PER_NODE=24 # GPU numbers
144
+ BATCH_SIZE_PER_GPU=42
145
+ DEBUG=false # if DEBUG=true, only a subset of the whole training dataset is loaded
146
+ python -m torch.distributed.launch --nproc_per_node=$NPROC_PER_NODE main.py \
147
+ --data_path $DATASET_ROOT \
148
+ --output_dir $OUTPUT_ROOT \
149
+ --arch vit_base \
150
+ --instance_queue_size 65536 \
151
+ --local_group_queue_size 65536 \
152
+ --use_bn_in_head false \
153
+ --instance_out_dim 256 \
154
+ --instance_temp 0.2 \
155
+ --local_group_out_dim 256 \
156
+ --local_group_temp 0.2 \
157
+ --local_group_knn_top_n 8 \
158
+ --group_out_dim 65536 \
159
+ --group_student_temp 0.1 \
160
+ --group_warmup_teacher_temp 0.04 \
161
+ --group_teacher_temp 0.07 \
162
+ --group_warmup_teacher_temp_epochs 50 \
163
+ --norm_last_layer false \
164
+ --norm_before_pred true \
165
+ --batch_size_per_gpu $BATCH_SIZE_PER_GPU \
166
+ --epochs 400 \
167
+ --warmup_epochs 10 \
168
+ --clip_grad 3.0 \
169
+ --lr 0.0008 \
170
+ --min_lr 2e-06 \
171
+ --patch_embed_lr_mult 0.2 \
172
+ --drop_path_rate 0.1 \
173
+ --weight_decay 0.04 \
174
+ --weight_decay_end 0.1 \
175
+ --freeze_last_layer 3 \
176
+ --momentum_teacher 0.996 \
177
+ --use_fp16 false \
178
+ --local_crops_number 10 \
179
+ --size_crops 96 \
180
+ --global_crops_scale 0.25 1 \
181
+ --local_crops_scale 0.05 0.25 \
182
+ --timm_auto_augment_par rand-m9-mstd0.5-inc1 \
183
+ --prob 0.5 \
184
+ --use_prefetcher true \
185
+ --debug $DEBUG
186
+
187
+ ## train ViT-large for 250 epochs
188
+ OUTPUT_ROOT=./exps/vit_large_250ep
189
+ NPROC_PER_NODE=40 # GPU numbers
190
+ BATCH_SIZE_PER_GPU=16
191
+ DEBUG=false # if DEBUG=true, only a subset of the whole training dataset is loaded
192
+ python -m torch.distributed.launch --nproc_per_node=$NPROC_PER_NODE main.py \
193
+ --data_path $DATASET_ROOT \
194
+ --output_dir $OUTPUT_ROOT \
195
+ --arch vit_large \
196
+ --instance_queue_size 65536 \
197
+ --local_group_queue_size 65536 \
198
+ --use_bn_in_head false \
199
+ --instance_out_dim 256 \
200
+ --instance_temp 0.2 \
201
+ --local_group_out_dim 256 \
202
+ --local_group_temp 0.2 \
203
+ --local_group_knn_top_n 8 \
204
+ --group_out_dim 65536 \
205
+ --group_student_temp 0.1 \
206
+ --group_warmup_teacher_temp 0.04 \
207
+ --group_teacher_temp 0.07 \
208
+ --group_warmup_teacher_temp_epochs 50 \
209
+ --norm_last_layer true \
210
+ --norm_before_pred true \
211
+ --batch_size_per_gpu $BATCH_SIZE_PER_GPU \
212
+ --epochs 250 \
213
+ --warmup_epochs 10 \
214
+ --clip_grad 3.0 \
215
+ --lr 0.0015 \
216
+ --min_lr 1.5e-4 \
217
+ --patch_embed_lr_mult 0.2 \
218
+ --drop_path_rate 0.3 \
219
+ --weight_decay 0.025 \
220
+ --weight_decay_end 0.08 \
221
+ --freeze_last_layer 3 \
222
+ --momentum_teacher 0.996 \
223
+ --use_fp16 false \
224
+ --local_crops_number 10 \
225
+ --size_crops 96 \
226
+ --global_crops_scale 0.25 1 \
227
+ --local_crops_scale 0.05 0.25 \
228
+ --timm_auto_augment_par rand-m9-mstd0.5-inc1 \
229
+ --prob 0.5 \
230
+ --use_prefetcher true \
231
+ --debug $DEBUG
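
For reference, the effective (global) batch size of each configuration above is simply NPROC_PER_NODE * BATCH_SIZE_PER_GPU; the short Python sketch below only carries out that arithmetic (the run names are illustrative labels):

# effective batch size = number of GPUs x per-GPU batch size
configs = {
    "vit_small_100ep": (8, 64),
    "vit_small_300ep": (16, 64),
    "vit_small_800ep": (16, 64),
    "vit_base_400ep": (24, 42),
    "vit_large_250ep": (40, 16),
}
for name, (gpus, bs) in configs.items():
    print(f"{name}: {gpus * bs}")   # 512, 1024, 1024, 1008, 640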
src/RandAugment.py ADDED
@@ -0,0 +1,506 @@
1
+ # Copyright 2022 Garena Online Private Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ implement AutoAugment and RandAugment
16
+ Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py and modified for token labeling
17
+ """
18
+ import math
19
+ import random
20
+ import re
21
+
22
+ import numpy as np
23
+ import PIL
24
+ from PIL import Image, ImageEnhance, ImageOps
25
+
26
+ _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])
27
+
28
+ _FILL = (128, 128, 128)
29
+
30
+ _MAX_LEVEL = 10.0
31
+
32
+ _HPARAMS_DEFAULT = dict(
33
+ translate_const=250,
34
+ img_mean=_FILL,
35
+ )
36
+
37
+ _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
38
+
39
+
40
+ def _interpolation(kwargs):
41
+ interpolation = kwargs.pop("resample", Image.BILINEAR)
42
+ if isinstance(interpolation, (list, tuple)):
43
+ return random.choice(interpolation)
44
+ else:
45
+ return interpolation
46
+
47
+
48
+ def _check_args_tf(kwargs):
49
+ if "fillcolor" in kwargs and _PIL_VER < (5, 0):
50
+ kwargs.pop("fillcolor")
51
+ kwargs["resample"] = _interpolation(kwargs)
52
+
53
+
54
+ def shear_x(img, factor, **kwargs):
55
+ _check_args_tf(kwargs)
56
+ return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
57
+
58
+
59
+ def shear_y(img, factor, **kwargs):
60
+ _check_args_tf(kwargs)
61
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
62
+
63
+
64
+ def translate_x_rel(img, pct, **kwargs):
65
+ pixels = pct * img.size[0]
66
+ _check_args_tf(kwargs)
67
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
68
+
69
+
70
+ def translate_y_rel(img, pct, **kwargs):
71
+ pixels = pct * img.size[1]
72
+ _check_args_tf(kwargs)
73
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
74
+
75
+
76
+ def translate_x_abs(img, pixels, **kwargs):
77
+ _check_args_tf(kwargs)
78
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
79
+
80
+
81
+ def translate_y_abs(img, pixels, **kwargs):
82
+ _check_args_tf(kwargs)
83
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
84
+
85
+
86
+ def rotate(img, degrees, **kwargs):
87
+ _check_args_tf(kwargs)
88
+ if _PIL_VER >= (5, 2):
89
+ return img.rotate(degrees, **kwargs)
90
+ elif _PIL_VER >= (5, 0):
91
+ w, h = img.size
92
+ post_trans = (0, 0)
93
+ rotn_center = (w / 2.0, h / 2.0)
94
+ angle = -math.radians(degrees)
95
+ matrix = [
96
+ round(math.cos(angle), 15),
97
+ round(math.sin(angle), 15),
98
+ 0.0,
99
+ round(-math.sin(angle), 15),
100
+ round(math.cos(angle), 15),
101
+ 0.0,
102
+ ]
103
+
104
+ def transform(x, y, matrix):
105
+ (a, b, c, d, e, f) = matrix
106
+ return a * x + b * y + c, d * x + e * y + f
107
+
108
+ matrix[2], matrix[5] = transform(
109
+ -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
110
+ )
111
+ matrix[2] += rotn_center[0]
112
+ matrix[5] += rotn_center[1]
113
+ return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
114
+ else:
115
+ return img.rotate(degrees, resample=kwargs["resample"])
116
+
117
+
118
+ def auto_contrast(img, **__):
119
+ return ImageOps.autocontrast(img)
120
+
121
+
122
+ def invert(img, **__):
123
+ return ImageOps.invert(img)
124
+
125
+
126
+ def equalize(img, **__):
127
+ return ImageOps.equalize(img)
128
+
129
+
130
+ def solarize(img, thresh, **__):
131
+ return ImageOps.solarize(img, thresh)
132
+
133
+
134
+ def solarize_add(img, add, thresh=128, **__):
135
+ lut = []
136
+ for i in range(256):
137
+ if i < thresh:
138
+ lut.append(min(255, i + add))
139
+ else:
140
+ lut.append(i)
141
+ if img.mode in ("L", "RGB"):
142
+ if img.mode == "RGB" and len(lut) == 256:
143
+ lut = lut + lut + lut
144
+ return img.point(lut)
145
+ else:
146
+ return img
147
+
148
+
149
+ def posterize(img, bits_to_keep, **__):
150
+ if bits_to_keep >= 8:
151
+ return img
152
+ return ImageOps.posterize(img, bits_to_keep)
153
+
154
+
155
+ def contrast(img, factor, **__):
156
+ return ImageEnhance.Contrast(img).enhance(factor)
157
+
158
+
159
+ def color(img, factor, **__):
160
+ return ImageEnhance.Color(img).enhance(factor)
161
+
162
+
163
+ def brightness(img, factor, **__):
164
+ return ImageEnhance.Brightness(img).enhance(factor)
165
+
166
+
167
+ def sharpness(img, factor, **__):
168
+ return ImageEnhance.Sharpness(img).enhance(factor)
169
+
170
+
171
+ def _randomly_negate(v):
172
+ """With 50% prob, negate the value"""
173
+ return -v if random.random() > 0.5 else v
174
+
175
+
176
+ def _rotate_level_to_arg(level, _hparams):
177
+ # range [-30, 30]
178
+ level = (level / _MAX_LEVEL) * 30.0
179
+ level = _randomly_negate(level)
180
+ return (level,)
181
+
182
+
183
+ def _enhance_level_to_arg(level, _hparams):
184
+ # range [0.1, 1.9]
185
+ return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
186
+
187
+
188
+ def _enhance_increasing_level_to_arg(level, _hparams):
189
+ # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
190
+ # range [0.1, 1.9]
191
+ level = (level / _MAX_LEVEL) * 0.9
192
+ level = 1.0 + _randomly_negate(level)
193
+ return (level,)
194
+
195
+
196
+ def _shear_level_to_arg(level, _hparams):
197
+ # range [-0.3, 0.3]
198
+ level = (level / _MAX_LEVEL) * 0.3
199
+ level = _randomly_negate(level)
200
+ return (level,)
201
+
202
+
203
+ def _translate_abs_level_to_arg(level, hparams):
204
+ translate_const = hparams["translate_const"]
205
+ level = (level / _MAX_LEVEL) * float(translate_const)
206
+ level = _randomly_negate(level)
207
+ return (level,)
208
+
209
+
210
+ def _translate_rel_level_to_arg(level, hparams):
211
+ # default range [-0.45, 0.45]
212
+ translate_pct = hparams.get("translate_pct", 0.45)
213
+ level = (level / _MAX_LEVEL) * translate_pct
214
+ level = _randomly_negate(level)
215
+ return (level,)
216
+
217
+
218
+ def _posterize_level_to_arg(level, _hparams):
219
+ # As per Tensorflow TPU EfficientNet impl
220
+ # range [0, 4], 'keep 0 up to 4 MSB of original image'
221
+ # intensity/severity of augmentation decreases with level
222
+ return (int((level / _MAX_LEVEL) * 4),)
223
+
224
+
225
+ def _posterize_increasing_level_to_arg(level, hparams):
226
+ # As per Tensorflow models research and UDA impl
227
+ # range [4, 0], 'keep 4 down to 0 MSB of original image',
228
+ # intensity/severity of augmentation increases with level
229
+ return (4 - _posterize_level_to_arg(level, hparams)[0],)
230
+
231
+
232
+ def _posterize_original_level_to_arg(level, _hparams):
233
+ # As per original AutoAugment paper description
234
+ # range [4, 8], 'keep 4 up to 8 MSB of image'
235
+ # intensity/severity of augmentation decreases with level
236
+ return (int((level / _MAX_LEVEL) * 4) + 4,)
237
+
238
+
239
+ def _solarize_level_to_arg(level, _hparams):
240
+ # range [0, 256]
241
+ # intensity/severity of augmentation decreases with level
242
+ return (int((level / _MAX_LEVEL) * 256),)
243
+
244
+
245
+ def _solarize_increasing_level_to_arg(level, _hparams):
246
+ # range [0, 256]
247
+ # intensity/severity of augmentation increases with level
248
+ return (256 - _solarize_level_to_arg(level, _hparams)[0],)
249
+
250
+
251
+ def _solarize_add_level_to_arg(level, _hparams):
252
+ # range [0, 110]
253
+ return (int((level / _MAX_LEVEL) * 110),)
254
+
255
+
256
+ LEVEL_TO_ARG = {
257
+ "AutoContrast": None,
258
+ "Equalize": None,
259
+ "Invert": None,
260
+ "Rotate": _rotate_level_to_arg,
261
+ # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
262
+ "Posterize": _posterize_level_to_arg,
263
+ "PosterizeIncreasing": _posterize_increasing_level_to_arg,
264
+ "PosterizeOriginal": _posterize_original_level_to_arg,
265
+ "Solarize": _solarize_level_to_arg,
266
+ "SolarizeIncreasing": _solarize_increasing_level_to_arg,
267
+ "SolarizeAdd": _solarize_add_level_to_arg,
268
+ "Color": _enhance_level_to_arg,
269
+ "ColorIncreasing": _enhance_increasing_level_to_arg,
270
+ "Contrast": _enhance_level_to_arg,
271
+ "ContrastIncreasing": _enhance_increasing_level_to_arg,
272
+ "Brightness": _enhance_level_to_arg,
273
+ "BrightnessIncreasing": _enhance_increasing_level_to_arg,
274
+ "Sharpness": _enhance_level_to_arg,
275
+ "SharpnessIncreasing": _enhance_increasing_level_to_arg,
276
+ "ShearX": _shear_level_to_arg,
277
+ "ShearY": _shear_level_to_arg,
278
+ "TranslateX": _translate_abs_level_to_arg,
279
+ "TranslateY": _translate_abs_level_to_arg,
280
+ "TranslateXRel": _translate_rel_level_to_arg,
281
+ "TranslateYRel": _translate_rel_level_to_arg,
282
+ }
283
+
284
+
285
+ NAME_TO_OP = {
286
+ "AutoContrast": auto_contrast,
287
+ "Equalize": equalize,
288
+ "Invert": invert,
289
+ "Rotate": rotate,
290
+ "Posterize": posterize,
291
+ "PosterizeIncreasing": posterize,
292
+ "PosterizeOriginal": posterize,
293
+ "Solarize": solarize,
294
+ "SolarizeIncreasing": solarize,
295
+ "SolarizeAdd": solarize_add,
296
+ "Color": color,
297
+ "ColorIncreasing": color,
298
+ "Contrast": contrast,
299
+ "ContrastIncreasing": contrast,
300
+ "Brightness": brightness,
301
+ "BrightnessIncreasing": brightness,
302
+ "Sharpness": sharpness,
303
+ "SharpnessIncreasing": sharpness,
304
+ "ShearX": shear_x,
305
+ "ShearY": shear_y,
306
+ "TranslateX": translate_x_abs,
307
+ "TranslateY": translate_y_abs,
308
+ "TranslateXRel": translate_x_rel,
309
+ "TranslateYRel": translate_y_rel,
310
+ }
311
+
312
+ _RAND_TRANSFORMS = [
313
+ "AutoContrast",
314
+ "Equalize",
315
+ "Invert",
316
+ "Rotate",
317
+ "Posterize",
318
+ "Solarize",
319
+ "SolarizeAdd",
320
+ "Color",
321
+ "Contrast",
322
+ "Brightness",
323
+ "Sharpness",
324
+ "ShearX",
325
+ "ShearY",
326
+ "TranslateXRel",
327
+ "TranslateYRel",
328
+ #'Cutout'
329
+ ]
330
+
331
+
332
+ _RAND_INCREASING_TRANSFORMS = [
333
+ "AutoContrast",
334
+ "Equalize",
335
+ "Invert",
336
+ "Rotate",
337
+ "PosterizeIncreasing",
338
+ "SolarizeIncreasing",
339
+ "SolarizeAdd",
340
+ "ColorIncreasing",
341
+ "ContrastIncreasing",
342
+ "BrightnessIncreasing",
343
+ "SharpnessIncreasing",
344
+ "ShearX",
345
+ "ShearY",
346
+ "TranslateXRel",
347
+ "TranslateYRel",
348
+ #'Cutout'
349
+ ]
350
+
351
+
352
+ # These experimental weights are based loosely on the relative improvements mentioned in the paper.
353
+ # They may not result in increased performance, but could likely be tuned to do so.
354
+ _RAND_CHOICE_WEIGHTS_0 = {
355
+ "Rotate": 0.3,
356
+ "ShearX": 0.2,
357
+ "ShearY": 0.2,
358
+ "TranslateXRel": 0.1,
359
+ "TranslateYRel": 0.1,
360
+ "Color": 0.025,
361
+ "Sharpness": 0.025,
362
+ "AutoContrast": 0.025,
363
+ "Solarize": 0.005,
364
+ "SolarizeAdd": 0.005,
365
+ "Contrast": 0.005,
366
+ "Brightness": 0.005,
367
+ "Equalize": 0.005,
368
+ "Posterize": 0,
369
+ "Invert": 0,
370
+ }
371
+
372
+
373
+ def _select_rand_weights(weight_idx=0, transforms=None):
374
+ transforms = transforms or _RAND_TRANSFORMS
375
+ assert weight_idx == 0 # only one set of weights currently
376
+ rand_weights = _RAND_CHOICE_WEIGHTS_0
377
+ probs = [rand_weights[k] for k in transforms]
378
+ probs /= np.sum(probs)
379
+ return probs
380
+
381
+
382
+ class AugmentOp:
383
+ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
384
+ hparams = hparams or _HPARAMS_DEFAULT
385
+ self.name = name
386
+ self.aug_fn = NAME_TO_OP[name]
387
+ self.level_fn = LEVEL_TO_ARG[name]
388
+ self.prob = prob
389
+ self.magnitude = magnitude
390
+ self.hparams = hparams.copy()
391
+ self.kwargs = dict(
392
+ fillcolor=hparams["img_mean"] if "img_mean" in hparams else _FILL,
393
+ resample=hparams["interpolation"]
394
+ if "interpolation" in hparams
395
+ else _RANDOM_INTERPOLATION,
396
+ )
397
+
398
+ # If magnitude_std is > 0, we introduce some randomness
399
+ # in the usually fixed policy and sample magnitude from a normal distribution
400
+ # with mean `magnitude` and std-dev of `magnitude_std`.
401
+ # NOTE This is my own hack, being tested, not in papers or reference impls.
402
+ self.magnitude_std = self.hparams.get("magnitude_std", 0)
403
+
404
+ def __call__(self, img):
405
+ if self.prob < 1.0 and random.random() > self.prob:
406
+ return img
407
+ magnitude = self.magnitude
408
+ if self.magnitude_std and self.magnitude_std > 0:
409
+ magnitude = random.gauss(magnitude, self.magnitude_std)
410
+ magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
411
+ level_args = (
412
+ self.level_fn(magnitude, self.hparams)
413
+ if self.level_fn is not None
414
+ else tuple()
415
+ )
416
+ imgs = self.aug_fn(img, *level_args, **self.kwargs)
417
+
418
+ return imgs
419
+
420
+
421
+ def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
422
+ hparams = hparams or _HPARAMS_DEFAULT
423
+ transforms = transforms or _RAND_TRANSFORMS
424
+ return [
425
+ AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)
426
+ for name in transforms
427
+ ]
428
+
429
+
430
+ class RandAugment:
431
+ """
432
+ Apply RandAug on image
433
+ """
434
+
435
+ def __init__(self, ops, num_layers=2, choice_weights=None):
436
+ self.ops = ops
437
+ self.num_layers = num_layers
438
+ self.choice_weights = choice_weights
439
+
440
+ def __call__(self, img):
441
+ # no replacement when using weighted choice
442
+ ops = np.random.choice(
443
+ self.ops, self.num_layers, replace=False, p=self.choice_weights
444
+ )
445
+ for op in ops:
446
+ img = op(img)
447
+
448
+ return img
449
+
450
+
451
+ def rand_augment_transform(config_str, hparams):
452
+ """
453
+ Create a RandAugment transform
454
+ :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
455
+ dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
456
+ sections, which are not order specific, determine
457
+ 'm' - integer magnitude of rand augment
458
+ 'n' - integer num layers (number of transform ops selected per image)
459
+ 'w' - integer probability weight index (index of a set of weights to influence choice of op)
460
+ 'mstd' - float std deviation of magnitude noise applied
461
+ 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
462
+ Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
463
+ 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
464
+
465
+ :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
466
+
467
+ :return: A PyTorch compatible Transform
468
+ """
469
+ magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
470
+ num_layers = 2 # default to 2 ops per image
471
+ weight_idx = None # default to no probability weights for op choice
472
+ transforms = _RAND_TRANSFORMS
473
+ config = config_str.split("-")
474
+ assert config[0] == "rand"
475
+ config = config[1:]
476
+ for c in config:
477
+ cs = re.split(r"(\d.*)", c)
478
+ if len(cs) < 2:
479
+ continue
480
+ key, val = cs[:2]
481
+ if key == "mstd":
482
+ # noise param injected via hparams for now
483
+ hparams.setdefault("magnitude_std", float(val))
484
+ elif key == "inc":
485
+ if bool(val): # this path
486
+ transforms = _RAND_INCREASING_TRANSFORMS
487
+ elif key == "m":
488
+ magnitude = int(val)
489
+ elif key == "n":
490
+ num_layers = int(val)
491
+ elif key == "w":
492
+ weight_idx = int(val)
493
+ else:
494
+ assert False, "Unknown RandAugment config section"
495
+ # magnitude 9
496
+ # hparams {'translate_const': 100, 'img_mean': (124, 116, 104), 'magnitude_std': 0.5}
497
+ # transforms ['AutoContrast', 'Equalize', 'Invert', 'Rotate', 'PosterizeIncreasing', \
498
+ # 'SolarizeIncreasing', 'SolarizeAdd', 'ColorIncreasing', 'ContrastIncreasing', \
499
+ # 'BrightnessIncreasing', 'SharpnessIncreasing', 'ShearX', 'ShearY', 'TranslateXRel', 'TranslateYRel']
500
+ ra_ops = rand_augment_ops(
501
+ magnitude=magnitude, hparams=hparams, transforms=transforms
502
+ )
503
+ choice_weights = (
504
+ None if weight_idx is None else _select_rand_weights(weight_idx)
505
+ ) ## None
506
+ return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
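
A minimal usage sketch of rand_augment_transform with the config string used in pretraining.sh ("rand-m9-mstd0.5-inc1"); the hparams values mirror the example noted in the comments above and are illustrative rather than required:

from PIL import Image
from src.RandAugment import rand_augment_transform

hparams = {"translate_const": 100, "img_mean": (124, 116, 104)}
augment = rand_augment_transform("rand-m9-mstd0.5-inc1", hparams)

img = Image.new("RGB", (224, 224), (128, 128, 128))   # placeholder image
aug_img = augment(img)   # applies 2 randomly chosen "increasing" ops around magnitude 9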
src/dataset.py ADDED
@@ -0,0 +1,367 @@
1
+ # Copyright 2022 Garena Online Private Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ ImageFolder dataset functions.
16
+
17
+ Mostly copy-paste from torchvision references
18
+ """
19
+ import os
20
+ import os.path
21
+ from typing import Any, Callable, Dict, List, Optional, Tuple, cast
22
+
23
+ from PIL import Image
24
+ from torchvision.datasets.vision import VisionDataset
25
+
26
+
27
+ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
28
+ """Checks if a file is an allowed extension.
29
+
30
+ Args:
31
+ filename (string): path to a file
32
+ extensions (tuple of strings): extensions to consider (lowercase)
33
+
34
+ Returns:
35
+ bool: True if the filename ends with one of given extensions
36
+ """
37
+ return filename.lower().endswith(extensions)
38
+
39
+
40
+ def is_image_file(filename: str) -> bool:
41
+ """Checks if a file is an allowed image extension.
42
+
43
+ Args:
44
+ filename (string): path to a file
45
+
46
+ Returns:
47
+ bool: True if the filename ends with a known image extension
48
+ """
49
+ return has_file_allowed_extension(filename, IMG_EXTENSIONS)
50
+
51
+
52
+ def find_classes(directory: str, class_num: int) -> Tuple[List[str], Dict[str, int]]:
53
+ """Finds the class folders in a dataset.
54
+
55
+ See :class:`DatasetFolder` for details.
56
+ """
57
+ classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
58
+ if not classes:
59
+ raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
60
+ classes = classes[:class_num]
61
+ class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
62
+ return classes, class_to_idx
63
+
64
+
65
+ def make_dataset(
66
+ directory: str,
67
+ class_to_idx: Optional[Dict[str, int]] = None,
68
+ extensions: Optional[Tuple[str, ...]] = None,
69
+ is_valid_file: Optional[Callable[[str], bool]] = None,
70
+ class_num=10,
71
+ ) -> List[Tuple[str, int]]:
72
+ """Generates a list of samples of a form (path_to_sample, class).
73
+
74
+ See :class:`DatasetFolder` for details.
75
+
76
+ Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
77
+ by default.
78
+ """
79
+ directory = os.path.expanduser(directory)
80
+
81
+ if class_to_idx is None:
82
+ _, class_to_idx = find_classes(directory, class_num)
83
+ elif not class_to_idx:
84
+ raise ValueError(
85
+ "'class_to_index' must have at least one entry to collect any samples."
86
+ )
87
+
88
+ both_none = extensions is None and is_valid_file is None
89
+ both_something = extensions is not None and is_valid_file is not None
90
+ if both_none or both_something:
91
+ raise ValueError(
92
+ "Both extensions and is_valid_file cannot be None or not None at the same time"
93
+ )
94
+
95
+ if extensions is not None:
96
+
97
+ def is_valid_file(x: str) -> bool:
98
+ return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
99
+
100
+ is_valid_file = cast(Callable[[str], bool], is_valid_file)
101
+
102
+ instances = []
103
+ available_classes = set()
104
+ for target_class in sorted(class_to_idx.keys()):
105
+ class_index = class_to_idx[target_class]
106
+ target_dir = os.path.join(directory, target_class)
107
+ if not os.path.isdir(target_dir):
108
+ continue
109
+ for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
110
+ for fname in sorted(fnames):
111
+ path = os.path.join(root, fname)
112
+ if is_valid_file(path):
113
+ item = path, class_index
114
+ instances.append(item)
115
+
116
+ if target_class not in available_classes:
117
+ available_classes.add(target_class)
118
+
119
+ empty_classes = set(class_to_idx.keys()) - available_classes
120
+ if empty_classes:
121
+ msg = (
122
+ f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
123
+ )
124
+ if extensions is not None:
125
+ msg += f"Supported extensions are: {', '.join(extensions)}"
126
+ raise FileNotFoundError(msg)
127
+
128
+ return instances
129
+
130
+
131
+ class DatasetFolder(VisionDataset):
132
+ """A generic data loader.
133
+
134
+ This default directory structure can be customized by overriding the
135
+ :meth:`find_classes` method.
136
+
137
+ Args:
138
+ root (string): Root directory path.
139
+ loader (callable): A function to load a sample given its path.
140
+ extensions (tuple[string]): A list of allowed extensions.
141
+ both extensions and is_valid_file should not be passed.
142
+ transform (callable, optional): A function/transform that takes in
143
+ a sample and returns a transformed version.
144
+ E.g, ``transforms.RandomCrop`` for images.
145
+ target_transform (callable, optional): A function/transform that takes
146
+ in the target and transforms it.
147
+ is_valid_file (callable, optional): A function that takes path of a file
148
+ and checks if the file is a valid file (used to check for corrupt files);
149
+ both extensions and is_valid_file should not be passed.
150
+ class_num: how many classes will be loaded
151
+ Attributes:
152
+ classes (list): List of the class names sorted alphabetically.
153
+ class_to_idx (dict): Dict with items (class_name, class_index).
154
+ samples (list): List of (sample path, class_index) tuples
155
+ targets (list): The class_index value for each image in the dataset
156
+ """
157
+
158
+ def __init__(
159
+ self,
160
+ root: str,
161
+ loader: Callable[[str], Any],
162
+ extensions: Optional[Tuple[str, ...]] = None,
163
+ transform: Optional[Callable] = None,
164
+ target_transform: Optional[Callable] = None,
165
+ is_valid_file: Optional[Callable[[str], bool]] = None,
166
+ class_num=10,
167
+ ) -> None:
168
+ super(DatasetFolder, self).__init__(
169
+ root, transform=transform, target_transform=target_transform
170
+ )
171
+ classes, class_to_idx = self.find_classes(self.root, class_num=class_num)
172
+ samples = self.make_dataset(
173
+ self.root, class_to_idx, extensions, is_valid_file, class_num=class_num
174
+ )
175
+
176
+ self.loader = loader
177
+ self.extensions = extensions
178
+
179
+ self.classes = classes
180
+ self.class_to_idx = class_to_idx
181
+ self.samples = samples
182
+ self.targets = [s[1] for s in samples]
183
+
184
+ @staticmethod
185
+ def make_dataset(
186
+ directory: str,
187
+ class_to_idx: Dict[str, int],
188
+ extensions: Optional[Tuple[str, ...]] = None,
189
+ is_valid_file: Optional[Callable[[str], bool]] = None,
190
+ class_num=10,
191
+ ) -> List[Tuple[str, int]]:
192
+ """Generates a list of samples of a form (path_to_sample, class).
193
+
194
+ This can be overridden to e.g. read files from a compressed zip file instead of from the disk.
195
+
196
+ Args:
197
+ directory (str): root dataset directory, corresponding to ``self.root``.
198
+ class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
199
+ extensions (optional): A list of allowed extensions.
200
+ Either extensions or is_valid_file should be passed. Defaults to None.
201
+ is_valid_file (optional): A function that takes path of a file
202
+ and checks if the file is a valid file
203
+ (used to check for corrupt files); both extensions and
204
+ is_valid_file should not be passed. Defaults to None.
205
+ class_num: how many classes will be loaded
206
+ Raises:
207
+ ValueError: In case ``class_to_idx`` is empty.
208
+ ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
209
+ FileNotFoundError: In case no valid file was found for any class.
210
+
211
+ Returns:
212
+ List[Tuple[str, int]]: samples of a form (path_to_sample, class)
213
+ """
214
+ if class_to_idx is None:
215
+ # prevent potential bug since make_dataset() would use the class_to_idx logic of the
216
+ # find_classes() function, instead of using that of the find_classes() method, which
217
+ # is potentially overridden and thus could have a different logic.
218
+ raise ValueError("The class_to_idx parameter cannot be None.")
219
+ return make_dataset(
220
+ directory,
221
+ class_to_idx,
222
+ extensions=extensions,
223
+ is_valid_file=is_valid_file,
224
+ class_num=class_num,
225
+ )
226
+
227
+ def find_classes(
228
+ self, directory: str, class_num: int
229
+ ) -> Tuple[List[str], Dict[str, int]]:
230
+ """Find the class folders in a dataset structured as follows::
231
+
232
+ directory/
233
+ ├── class_x
234
+ │ ├── xxx.ext
235
+ │ ├── xxy.ext
236
+ │ └── ...
237
+ │ └── xxz.ext
238
+ └── class_y
239
+ ├── 123.ext
240
+ ├── nsdf3.ext
241
+ └── ...
242
+ └── asd932_.ext
243
+
244
+ This method can be overridden to only consider
245
+ a subset of classes, or to adapt to a different dataset directory structure.
246
+
247
+ Args:
248
+ directory(str): Root directory path, corresponding to ``self.root``
249
+
250
+ Raises:
251
+ FileNotFoundError: If ``dir`` has no class folders.
252
+
253
+ Returns:
254
+ (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
255
+ """
256
+ return find_classes(directory, class_num=class_num)
257
+
258
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
259
+ """
260
+ Args:
261
+ index (int): Index
262
+
263
+ Returns:
264
+ tuple: (sample, target) where target is class_index of the target class.
265
+ """
266
+ path, target = self.samples[index]
267
+ sample = self.loader(path)
268
+ if self.transform is not None:
269
+ sample = self.transform(sample)
270
+ # if self.target_transform is not None:
271
+ # target = self.target_transform(target)
272
+
273
+ return sample # , target
274
+
275
+ def __len__(self) -> int:
276
+ return len(self.samples)
277
+
278
+
279
+ IMG_EXTENSIONS = (
280
+ ".jpg",
281
+ ".jpeg",
282
+ ".png",
283
+ ".ppm",
284
+ ".bmp",
285
+ ".pgm",
286
+ ".tif",
287
+ ".tiff",
288
+ ".webp",
289
+ )
290
+
291
+
292
+ def pil_loader(path: str) -> Image.Image:
293
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
294
+ with open(path, "rb") as f:
295
+ img = Image.open(f)
296
+ return img.convert("RGB")
297
+
298
+
299
+ # TODO: specify the return type
300
+ def accimage_loader(path: str) -> Any:
301
+ import accimage
302
+
303
+ try:
304
+ return accimage.Image(path)
305
+ except IOError:
306
+ # Potentially a decoding problem, fall back to PIL.Image
307
+ return pil_loader(path)
308
+
309
+
310
+ def default_loader(path: str) -> Any:
311
+ from torchvision import get_image_backend
312
+
313
+ if get_image_backend() == "accimage":
314
+ return accimage_loader(path)
315
+ else:
316
+ return pil_loader(path)
317
+
318
+
319
+ class ImageFolder(DatasetFolder):
320
+ """A generic data loader where the images are arranged in this way by default: ::
321
+
322
+ root/dog/xxx.png
323
+ root/dog/xxy.png
324
+ root/dog/[...]/xxz.png
325
+
326
+ root/cat/123.png
327
+ root/cat/nsdf3.png
328
+ root/cat/[...]/asd932_.png
329
+
330
+ This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
331
+ the same methods can be overridden to customize the dataset.
332
+
333
+ Args:
334
+ root (string): Root directory path.
335
+ transform (callable, optional): A function/transform that takes in an PIL image
336
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
337
+ target_transform (callable, optional): A function/transform that takes in the
338
+ target and transforms it.
339
+ loader (callable, optional): A function to load an image given its path.
340
+ is_valid_file (callable, optional): A function that takes path of an Image file
341
+ and checks if the file is a valid file (used to check for corrupt files);
342
+ class_num: how many classes will be loaded
343
+ Attributes:
344
+ classes (list): List of the class names sorted alphabetically.
345
+ class_to_idx (dict): Dict with items (class_name, class_index).
346
+ imgs (list): List of (image path, class_index) tuples
347
+ """
348
+
349
+ def __init__(
350
+ self,
351
+ root: str,
352
+ transform: Optional[Callable] = None,
353
+ target_transform: Optional[Callable] = None,
354
+ loader: Callable[[str], Any] = default_loader,
355
+ is_valid_file: Optional[Callable[[str], bool]] = None,
356
+ class_num=10,
357
+ ):
358
+ super(ImageFolder, self).__init__(
359
+ root,
360
+ loader,
361
+ IMG_EXTENSIONS if is_valid_file is None else None,
362
+ transform=transform,
363
+ target_transform=target_transform,
364
+ is_valid_file=is_valid_file,
365
+ class_num=class_num,
366
+ )
367
+ self.imgs = self.samples
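
A small usage sketch of this ImageFolder variant; the dataset path and class_num are placeholders. Note that, unlike the torchvision version, __getitem__ returns only the transformed image (the target return is commented out above):

from torchvision import transforms
from src.dataset import ImageFolder

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
])
# hypothetical path; keep only the first 100 class folders
dataset = ImageFolder("/path/to/train", transform=transform, class_num=100)
print(len(dataset))   # number of images in the selected classes
img = dataset[0]      # a transformed image tensor, no label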
src/loss.py ADDED
@@ -0,0 +1,244 @@
1
+ # Copyright 2022 Garena Online Private Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ functions for building multi-granular losses.
16
+ """
17
+ import numpy as np
18
+ import torch
19
+ import torch.distributed as dist
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from utils import concat_all_gather
24
+
25
+
26
+ class InfoNCELoss(nn.Module):
27
+ """
28
+ vanilla InfoNCE loss.
29
+ --ncrops: how many crops are used in student networks
30
+ --dim: feature dimension of the queue, determined by the output dimension of the student network
31
+ --queue_size: queue size
32
+ --temperature: temperature parameter for the InfoNCE loss
33
+ """
34
+
35
+ def __init__(self, ncrops, dim=256, queue_size=65536, temperature=0.2):
36
+ super().__init__()
37
+ self.queue_size = queue_size
38
+ self.temperature = temperature
39
+
40
+ self.register_buffer("queue", torch.randn(dim, queue_size))
41
+ self.queue = nn.functional.normalize(self.queue, dim=0)
42
+
43
+ self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
44
+ self.CrossEntropyLoss = nn.CrossEntropyLoss()
45
+ self.ncrops = ncrops
46
+
47
+ @torch.no_grad()
48
+ def _dequeue_and_enqueue(self, keys):
49
+ """
50
+ queue update
51
+ """
52
+ keys = concat_all_gather(keys)
53
+ batch_size = keys.shape[0]
54
+ ptr = int(self.queue_ptr)
55
+ # replace the keys at ptr (dequeue and enqueue)
56
+ if ptr + batch_size <= self.queue_size:
57
+ self.queue[:, ptr : ptr + batch_size] = keys.T
58
+ ptr = (ptr + batch_size) % self.queue_size
59
+ else:
60
+ keys_t = keys.T
61
+ queue_remaining_size = self.queue_size - ptr
62
+ self.queue[:, ptr:] = keys_t[:, :queue_remaining_size]
63
+ self.queue[:, : batch_size - queue_remaining_size] = keys_t[
64
+ :, queue_remaining_size:
65
+ ]
66
+
67
+ ptr = batch_size - queue_remaining_size # move pointer
68
+
69
+ self.queue_ptr[0] = ptr
70
+
71
+ # student_output, teacher_output
72
+ def forward(self, student_output, teacher_output, epoch):
73
+ """
74
+ Cross-entropy between softmax outputs of the teacher and student networks.
75
+ """
76
+ preds = student_output.chunk(self.ncrops)
77
+ targets = teacher_output.detach().chunk(2)
78
+ small_crop_loss, large_crop_loss = 0, 0
79
+ small_loss_terms, large_loss_terms = 0, 0
80
+ queue_feat = self.queue.clone().detach()
81
+
82
+ for t_idx, targ in enumerate(targets):
83
+ for p_idx, pred in enumerate(preds):
84
+ if t_idx == p_idx:
85
+ continue
86
+ # positive logits: Nx1
87
+ l_pos = torch.einsum("nc,nc->n", [pred, targ]).unsqueeze(-1)
88
+ # negative logits: NxK
89
+ l_neg = torch.einsum("nc,ck->nk", [pred, queue_feat])
90
+ # logits: Nx(1+K)
91
+ logits = torch.cat([l_pos, l_neg], dim=1)
92
+ # apply temperature
93
+ logits /= self.temperature
94
+ # labels: positive key indicators
95
+ labels = torch.zeros(logits.shape[0], dtype=torch.long).to(
96
+ logits.device
97
+ )
98
+ loss = self.CrossEntropyLoss(logits, labels)
99
+ if p_idx < 2: ## large crop loss, namely loss on 224-sized images
100
+ large_crop_loss += loss
101
+ large_loss_terms += 1
102
+ else: ## small crop loss, namely loss on 96-sized images
103
+ small_crop_loss += loss
104
+ small_loss_terms += 1
105
+ # dequeue and enqueue
106
+ self._dequeue_and_enqueue(targ)
107
+
108
+ large_crop_loss /= large_loss_terms
109
+ small_crop_loss /= small_loss_terms
110
+ loss = 0.5 * (large_crop_loss + small_crop_loss)
111
+ return loss
112
+
113
+
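
The logits built in the loop above (one positive per sample plus queue_size negatives, scaled by the temperature) can be checked in isolation; the sizes below are illustrative:

import torch
import torch.nn.functional as F

N, dim, queue_size, temperature = 4, 256, 65536, 0.2
pred = F.normalize(torch.randn(N, dim), dim=-1)
targ = F.normalize(torch.randn(N, dim), dim=-1)
queue_feat = F.normalize(torch.randn(dim, queue_size), dim=0)

l_pos = torch.einsum("nc,nc->n", [pred, targ]).unsqueeze(-1)   # N x 1
l_neg = torch.einsum("nc,ck->nk", [pred, queue_feat])          # N x K
logits = torch.cat([l_pos, l_neg], dim=1) / temperature        # N x (1 + K)
labels = torch.zeros(N, dtype=torch.long)                      # positive sits at index 0
loss = torch.nn.CrossEntropyLoss()(logits, labels)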
114
+ class ClusteringLoss(nn.Module):
115
+ """
116
+ Clustering loss which is very similar to the one in DINO
117
+ --out_dim: center dimension, determined by the output dimension of the student network
118
+ --ncrops: how many crops are used in student networks
119
+ --warmup_teacher_temp: Initial value for the teacher temperature
120
+ --teacher_temp: Final value (after linear warmup) of the teacher temperature
121
+ --warmup_teacher_temp_epochs: Number of warmup epochs for the teacher temperature
122
+ --nepochs: total number of training epochs
123
+ --student_temp: temperature parameter in student output
124
+ --center_momentum: EMA parameter for center update
125
+ """
126
+
127
+ def __init__(
128
+ self,
129
+ out_dim,
130
+ ncrops,
131
+ warmup_teacher_temp,
132
+ teacher_temp,
133
+ warmup_teacher_temp_epochs,
134
+ nepochs,
135
+ student_temp=0.1,
136
+ center_momentum=0.9,
137
+ ):
138
+ super().__init__()
139
+ self.student_temp = student_temp
140
+ self.center_momentum = center_momentum
141
+ self.ncrops = ncrops
142
+ self.register_buffer("center", torch.zeros(1, out_dim))
143
+ # we apply a warm up for the teacher temperature because
144
+ # a too-high temperature makes the training unstable at the beginning
145
+ self.teacher_temp_schedule = np.concatenate(
146
+ (
147
+ np.linspace(
148
+ warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs
149
+ ),
150
+ np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp,
151
+ )
152
+ )
153
+
154
+ def forward(self, student_output, teacher_output, epoch):
155
+ """
156
+ Cross-entropy between softmax outputs of the teacher and student networks.
157
+ """
158
+ student_out = student_output / self.student_temp
159
+ student_out = student_out.chunk(self.ncrops)
160
+
161
+ # teacher centering and sharpening
162
+ temp = self.teacher_temp_schedule[epoch]
163
+ teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
164
+ teacher_out = teacher_out.detach().chunk(2)
165
+
166
+ loss_large_crop, loss_small_crop = 0.0, 0.0
167
+ loss_terms_large_crop, loss_terms_small_crop = 0, 0
168
+ for iq, q in enumerate(teacher_out):
169
+ for v in range(len(student_out)):
170
+ if v == iq:
171
+ # we skip cases where student and teacher operate on the same view
172
+ continue
173
+ loss = torch.sum(
174
+ -q * F.log_softmax(student_out[v], dim=-1), dim=-1
175
+ ).mean()
176
+ if v < 2:
177
+ loss_large_crop += loss
178
+ loss_terms_large_crop += 1
179
+ else:
180
+ loss_small_crop += loss
181
+ loss_terms_small_crop += 1
182
+
183
+ self.update_center(teacher_output)
184
+ loss_large_crop /= loss_terms_large_crop
185
+ loss_small_crop /= loss_terms_small_crop
186
+ total_loss = 0.5 * (loss_large_crop + loss_small_crop)
187
+ return total_loss
188
+
189
+ @torch.no_grad()
190
+ def update_center(self, teacher_output):
191
+ """
192
+ Update center used for teacher output.
193
+ """
194
+ batch_center = torch.mean(teacher_output, dim=0, keepdim=False)
195
+ dist.all_reduce(batch_center)
196
+ batch_center = batch_center / dist.get_world_size()
197
+
198
+ # ema update
199
+ self.center = self.center * self.center_momentum + batch_center * (
200
+ 1 - self.center_momentum
201
+ )
202
+
203
+
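
The teacher temperature schedule built in ClusteringLoss.__init__ can be inspected on its own; the numbers below use the 300-epoch settings from pretraining.sh (warmup from 0.04 to 0.07 over 30 epochs, then constant):

import numpy as np

warmup_teacher_temp, teacher_temp = 0.04, 0.07
warmup_teacher_temp_epochs, nepochs = 30, 300
schedule = np.concatenate((
    np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
    np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp,
))
print(schedule[0], schedule[29], schedule[150])   # 0.04 0.07 0.07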
204
+ def get_multi_granular_loss(args):
205
+ """
206
+ build the multi-granular loss
207
+ """
208
+ all_losses, all_weights = {}, {}
209
+
210
+ ## build the instance discrimination loss
211
+ instance_supervision_loss = InfoNCELoss(
212
+ args.local_crops_number + 2,
213
+ dim=args.instance_out_dim,
214
+ queue_size=args.instance_queue_size,
215
+ temperature=args.instance_temp,
216
+ ).cuda()
217
+ all_losses["instance-sup."] = instance_supervision_loss
218
+ all_weights["instance-sup."] = args.loss_weights[0]
219
+
220
+ ## build the local group discrimination loss
221
+ local_group_supervision = InfoNCELoss(
222
+ args.local_crops_number + 2,
223
+ dim=args.local_group_out_dim,
224
+ queue_size=args.local_group_queue_size,
225
+ temperature=args.local_group_temp,
226
+ ).cuda()
227
+ all_losses["local-group-sup."] = local_group_supervision
228
+ all_weights["local-group-sup."] = args.loss_weights[1]
229
+
230
+ ## build the group discrimination loss
231
+ group_loss = ClusteringLoss(
232
+ args.group_out_dim,
233
+ args.local_crops_number
234
+ + 2, # total number of crops = 2 global crops + local_crops_number
235
+ args.group_warmup_teacher_temp,
236
+ args.group_teacher_temp,
237
+ args.group_warmup_teacher_temp_epochs,
238
+ args.epochs,
239
+ student_temp=args.group_student_temp,
240
+ center_momentum=0.9,
241
+ ).cuda()
242
+ all_losses["group-sup."] = group_loss
243
+ all_weights["group-sup."] = args.loss_weights[2]
244
+ return all_losses, all_weights
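
A hedged sketch of how the two returned dictionaries are typically combined into the single total loss that main.py logs; the weighted-sum form and the placeholder weights are assumptions here, the actual weights come from args.loss_weights:

import torch

# placeholder per-granularity loss values and weights (illustrative only)
granular = {
    "instance-sup.": torch.tensor(2.31),
    "local-group-sup.": torch.tensor(2.08),
    "group-sup.": torch.tensor(7.95),
}
weights = {"instance-sup.": 1.0, "local-group-sup.": 1.0, "group-sup.": 1.0}
total_loss = sum(weights[name] * value for name, value in granular.items())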
src/model.py ADDED
@@ -0,0 +1,607 @@
1
+ # Copyright 2022 Garena Online Private Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ models and functions for building student and teacher networks for multi-granular losses.
16
+ """
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ import src.vision_transformer as vits
21
+ from src.vision_transformer import trunc_normal_
22
+
23
+
24
+ class Instance_Superivsion_Head(nn.Module):
25
+ """
26
+ a class to implement the Instance Supervision Head
27
+ --in_dim: input dimension of projection head
28
+ --hidden_dim: hidden dimension of projection head
29
+ --out_dim: output dimension of projection and prediction heads
30
+ --pred_hidden_dim: hidden dimension of prediction head
31
+ --nlayers: number of layers in the projection head; the prediction head has nlayers-1 layers
32
+ --proj_bn: whether we use batch normalization in projection head
33
+ --pred_bn: whether we use batch normalization in prediction head
34
+ --norm_before_pred: whether we use normalization before prediction head
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ in_dim,
40
+ hidden_dim=2048,
41
+ out_dim=256,
42
+ pred_hidden_dim=4096,
43
+ nlayers=3,
44
+ proj_bn=False,
45
+ pred_bn=False,
46
+ norm_before_pred=True,
47
+ ):
48
+ super().__init__()
49
+ nlayers = max(nlayers, 1)
50
+ self.norm_before_pred = norm_before_pred
51
+
52
+ self.projector = self._build_mlp(
53
+ nlayers, in_dim, hidden_dim, out_dim, use_bn=proj_bn
54
+ )
55
+
56
+ self.apply(self._init_weights)
57
+
58
+ self.predictor = None
59
+ if pred_hidden_dim > 0: # teacher no, student yes
60
+ self.predictor = self._build_mlp(
61
+ nlayers - 1, out_dim, pred_hidden_dim, out_dim, use_bn=pred_bn
62
+ )
63
+
64
+ def _init_weights(self, m):
65
+ """
66
+ initialize the parameters of the network
67
+ """
68
+ if isinstance(m, nn.Linear):
69
+ trunc_normal_(m.weight, std=0.02)
70
+ if isinstance(m, nn.Linear) and m.bias is not None:
71
+ nn.init.constant_(m.bias, 0)
72
+
73
+ def _build_mlp(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False):
74
+ """
75
+ build a mlp
76
+ """
77
+ mlp = []
78
+ for layer in range(num_layers):
79
+ dim1 = input_dim if layer == 0 else hidden_dim
80
+ dim2 = output_dim if layer == num_layers - 1 else hidden_dim
81
+
82
+ mlp.append(nn.Linear(dim1, dim2, bias=False))
83
+
84
+ if layer < num_layers - 1:
85
+ if use_bn:
86
+ mlp.append(nn.BatchNorm1d(dim2))
87
+ mlp.append(nn.GELU())
88
+
89
+ return nn.Sequential(*mlp)
90
+
91
+ def forward(self, x, return_target=False):
92
+ """
93
+ forward the input through projection head for teacher and
94
+ projection/prediction heads for student
95
+ """
96
+ feat = self.projector(x)
97
+
98
+ if return_target:
99
+ feat = nn.functional.normalize(feat, dim=-1, p=2)
100
+ return feat
101
+ ## return prediction
102
+ if self.norm_before_pred:
103
+ feat = nn.functional.normalize(feat, dim=-1, p=2)
104
+ pred = self.predictor(feat)
105
+ pred = nn.functional.normalize(pred, dim=-1, p=2)
106
+ return pred
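
A minimal sketch of how this head is used for the two roles: the teacher path returns L2-normalized projections (return_target=True), while the student path additionally passes through the predictor MLP. The 384-dim input matches ViT-Small's embedding size and the other sizes follow the defaults above; the snippet assumes the repo is importable:

import torch
from src.model import Instance_Superivsion_Head

head = Instance_Superivsion_Head(in_dim=384, hidden_dim=2048, out_dim=256,
                                 pred_hidden_dim=4096, nlayers=3)
tokens = torch.randn(8, 384)                 # hypothetical class-token features
target = head(tokens, return_target=True)   # (8, 256), normalized projection only
pred = head(tokens)                          # (8, 256), normalized predictor output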
107
+
108
+
109
+ class Local_Group_Superivsion_Head(nn.Module):
110
+ """
111
+ a class to implement the Local Group Supervision Head, which has the same structure as the Instance Supervision Head
112
+ --in_dim: input dimension of projection head
113
+ --hidden_dim: hidden dimension of projection head
114
+ --out_dim: output dimension of projection and prediction heads
115
+ --pred_hidden_dim: hidden dimension of prediction head
116
+ --nlayers: number of layers in the projection head; the prediction head has nlayers-1 layers
117
+ --proj_bn: whether we use batch normalization in projection head
118
+ --pred_bn: whether we use batch normalization in prediction head
119
+ --norm_before_pred: whether we use normalization before prediction head
120
+ """
121
+
122
+ def __init__(
123
+ self,
124
+ in_dim,
125
+ hidden_dim=2048,
126
+ out_dim=256,
127
+ pred_hidden_dim=4096,
128
+ nlayers=3,
129
+ proj_bn=False,
130
+ pred_bn=False,
131
+ norm_before_pred=True,
132
+ ):
133
+ super().__init__()
134
+ nlayers = max(nlayers, 1)
135
+ self.norm_before_pred = norm_before_pred
136
+
137
+ self.projector = self._build_mlp(
138
+ nlayers, in_dim, hidden_dim, out_dim, use_bn=proj_bn
139
+ )
140
+
141
+ self.apply(self._init_weights)
142
+
143
+ self.predictor = None
144
+ if pred_hidden_dim > 0: # teacher no, student yes
145
+ self.predictor = self._build_mlp(
146
+ nlayers - 1, out_dim, pred_hidden_dim, out_dim, use_bn=pred_bn
147
+ )
148
+
149
+ def _init_weights(self, m):
150
+ """
151
+ initialize the parameters of the network
152
+ """
153
+ if isinstance(m, nn.Linear):
154
+ trunc_normal_(m.weight, std=0.02)
155
+ if isinstance(m, nn.Linear) and m.bias is not None:
156
+ nn.init.constant_(m.bias, 0)
157
+
158
+ def _build_mlp(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False):
159
+ """
160
+ build a mlp
161
+ """
162
+ mlp = []
163
+ for layer in range(num_layers):
164
+ dim1 = input_dim if layer == 0 else hidden_dim
165
+ dim2 = output_dim if layer == num_layers - 1 else hidden_dim
166
+
167
+ mlp.append(nn.Linear(dim1, dim2, bias=False))
168
+
169
+ if layer < num_layers - 1:
170
+ if use_bn:
171
+ mlp.append(nn.BatchNorm1d(dim2))
172
+ mlp.append(nn.GELU())
173
+
174
+ return nn.Sequential(*mlp)
175
+
176
+ def forward(self, x, return_target=False):
177
+ """
178
+ forward the input through projection head for teacher and
179
+ projection/prediction heads for student
180
+ """
181
+ feat = self.projector(x)
182
+
183
+ if return_target:
184
+ feat = nn.functional.normalize(feat, dim=-1, p=2)
185
+ return feat
186
+ ## return prediction
187
+ if self.norm_before_pred:
188
+ feat = nn.functional.normalize(feat, dim=-1, p=2)
189
+ pred = self.predictor(feat)
190
+ pred = nn.functional.normalize(pred, dim=-1, p=2)
191
+ return pred
192
+
193
+
194
+ class Group_Superivsion_Head(nn.Module):
195
+ """
196
+ a class to implement the Group Supervision Head: a projection head followed by
197
+ a weight-normalized last layer used for group (clustering) discrimination
198
+ --in_dim: input dimension of projection head
199
+ --out_dim: output dimension of the last layer
200
+ --hidden_dim: hidden dimension of projection head
201
+ --bottleneck_dim: bottleneck dimension before the last layer
202
+ --nlayers: layer number of projection head
203
+ --use_bn: whether we use batch normalization in projection head
204
+ --norm_last_layer: whether we fix the norm of the last layer weights
205
+ """
206
+
207
+ def __init__(
208
+ self,
209
+ in_dim,
210
+ out_dim,
211
+ hidden_dim=2048,
212
+ bottleneck_dim=256,
213
+ nlayers=3,
214
+ use_bn=False,
215
+ norm_last_layer=True,
216
+ ):
217
+ super().__init__()
218
+ nlayers = max(nlayers, 1)
219
+
220
+ self.projector = self._build_mlp(
221
+ nlayers, in_dim, hidden_dim, bottleneck_dim, use_bn=use_bn
222
+ )
223
+ self.apply(self._init_weights)
224
+
225
+ self.last_layer = nn.utils.weight_norm(
226
+ nn.Linear(bottleneck_dim, out_dim, bias=False)
227
+ )
228
+ self.last_layer.weight_g.data.fill_(1)
229
+ if norm_last_layer:
230
+ self.last_layer.weight_g.requires_grad = False
231
+
232
+ def _build_mlp(self, num_layers, in_dim, hidden_dim, output_dim, use_bn=False):
233
+ """
234
+ build a mlp
235
+ """
236
+ if num_layers == 1:
237
+ mlp = nn.Linear(in_dim, output_dim)
238
+ else:
239
+ layers = [nn.Linear(in_dim, hidden_dim)]
240
+ if use_bn:
241
+ layers.append(nn.BatchNorm1d(hidden_dim))
242
+ layers.append(nn.GELU())
243
+ for _ in range(num_layers - 2):
244
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
245
+ if use_bn:
246
+ layers.append(nn.BatchNorm1d(hidden_dim))
247
+ layers.append(nn.GELU())
248
+ layers.append(nn.Linear(hidden_dim, output_dim))
249
+ mlp = nn.Sequential(*layers)
250
+ return mlp
251
+
252
+ def _init_weights(self, m):
253
+ """
254
+ initialize the parameters of the network
255
+ """
256
+ if isinstance(m, nn.Linear):
257
+ trunc_normal_(m.weight, std=0.02)
258
+ if isinstance(m, nn.Linear) and m.bias is not None:
259
+ nn.init.constant_(m.bias, 0)
260
+
261
+ def forward(self, x):
262
+ """
263
+ forward the input through the projection and last prediction layer
264
+ """
265
+ feat = self.projector(x)
266
+ feat = nn.functional.normalize(feat, dim=-1, p=2)
267
+ feat = self.last_layer(feat)
268
+ return feat
269
+
270
+
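A small shape check for the group supervision head above (hypothetical; 65536 prototypes as in the DINO-style group targets mentioned in the comments of get_model below):

head = Group_Superivsion_Head(in_dim=384, out_dim=65536, hidden_dim=2048, bottleneck_dim=256)
scores = head(torch.randn(8, 384))
assert scores.shape == (8, 65536)  # per-sample prototype scores fed to the group loss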
271
+ class Block_mem(nn.Module):
272
+ """
273
+ a class to implement a memory block for local group supervision
274
+ --dim: feature vector dimension in the memory
275
+ --K: memory size
276
+ --top_n: number of neighbors used in local group supervision
277
+ """
278
+
279
+ def __init__(self, dim, K=2048, top_n=10):
280
+ super().__init__()
281
+ self.dim = dim
282
+ self.K = K
283
+ self.top_n = top_n
284
+ # create the queue
285
+ self.register_buffer("queue_q", torch.randn(K, dim))
286
+ self.register_buffer("queue_k", torch.randn(K, dim))
287
+ self.register_buffer("queue_v", torch.randn(K, dim))
288
+ self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
289
+
290
+ @torch.no_grad()
291
+ def _dequeue_and_enqueue(self, query, weak_aug_flags):
292
+ """
293
+ update memory queue
294
+ """
295
+ # import pdb
296
+ # pdb.set_trace()
297
+ len_weak = 0
298
+ query = concat_all_gather(query)
299
+ if weak_aug_flags is not None:
300
+ weak_aug_flags = weak_aug_flags.cuda()
301
+ weak_aug_flags = concat_all_gather(weak_aug_flags)
302
+ idx_weak = torch.nonzero(weak_aug_flags)
303
+ len_weak = len(idx_weak)
304
+ if len_weak > 0:
305
+ idx_weak = idx_weak.squeeze(-1)
306
+ query = query[idx_weak]
307
+ else:
308
+ return len_weak
309
+
310
+ all_size = query.shape[0]
311
+ ptr = int(self.queue_ptr)
312
+ remaining_size = ptr + all_size - self.K
313
+ if remaining_size <= 0:
314
+ self.queue_q[ptr : ptr + all_size, :] = query
315
+ self.queue_k[ptr : ptr + all_size, :] = query
316
+ self.queue_v[ptr : ptr + all_size, :] = query
317
+ self.queue_ptr[0] = (ptr + all_size) % self.K
319
+ else:
320
+ self.queue_q[ptr : self.K, :] = query[0 : self.K - ptr, :]
321
+ self.queue_k[ptr : self.K, :] = query[0 : self.K - ptr, :]
322
+ self.queue_v[ptr : self.K, :] = query[0 : self.K - ptr, :]
323
+
324
+ self.queue_q[0:remaining_size, :] = query[self.K - ptr :, :]
325
+ self.queue_k[0:remaining_size, :] = query[self.K - ptr :, :]
326
+ self.queue_v[0:remaining_size, :] = query[self.K - ptr :, :]
327
+ self.queue_ptr[0] = remaining_size
328
+ return len_weak
329
+
330
+ @torch.no_grad()
331
+ def _get_similarity_index(self, x):
332
+ """
333
+ compute the index of the top-n neighbors (key-value pair) in memory
334
+ """
335
+ x = nn.functional.normalize(x, dim=-1)
336
+ queue_q = nn.functional.normalize(self.queue_q, dim=-1)
337
+
338
+ cosine = x @ queue_q.T
339
+ _, index = torch.topk(cosine, self.top_n, dim=-1)
340
+ return index
341
+
342
+ @torch.no_grad()
343
+ def _get_similarity_samples(self, query, index=None):
344
+ """
345
+ compute top-n neighbors (key-value pair) in memory
346
+ """
347
+ if index is None:
348
+ index = self._get_similarity_index(query)
349
+ get_k = self.queue_k[index.view(-1)]
350
+ get_v = self.queue_v[index.view(-1)]
351
+ B, tn = index.shape
352
+ get_k = get_k.view(B, tn, self.dim)
353
+ get_v = get_v.view(B, tn, self.dim)
354
+ return get_k, get_v
355
+
356
+ def forward(self, query):
357
+ """
358
+ forward to find the top-n neighbors (key-value pair) in memory
359
+ """
360
+ get_k, get_v = self._get_similarity_samples(query)
361
+ return get_k, get_v
362
+
363
+
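A rough illustration of querying the memory block above (hypothetical values; the distributed gather is only needed when updating the queue, not for this lookup):

mem = Block_mem(dim=384, K=2048, top_n=10)
keys, values = mem(torch.randn(16, 384))  # top-10 neighbours per query: both of shape [16, 10, 384]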
364
+ class vit_mem(nn.Module):
365
+ """
366
+ a class to implement a memory for local group supervision
367
+ --dim: feature vector dimension in the memory
368
+ --K: memory size
369
+ --top_n: number of neighbors used in local group supervision
370
+ """
371
+
372
+ def __init__(self, dim, K=2048, top_n=10):
373
+ super().__init__()
374
+ self.block = Block_mem(dim, K, top_n)
375
+
376
+ def _dequeue_and_enqueue(self, query, weak_aug_flags):
377
+ """
378
+ update memory queue
379
+ """
380
+ query = query.float()
381
+ weak_num = self.block._dequeue_and_enqueue(query, weak_aug_flags)
382
+ return weak_num
383
+
384
+ def forward(self, query):
385
+ """
386
+ forward to find the top-n neighbors (key-value pair) in memory
387
+ """
388
+ query = query.float()
389
+ get_k, get_v = self.block(query)
390
+ return get_k, get_v
391
+
392
+
393
+ class Mugs_Wrapper(nn.Module):
394
+ """
395
+ a class to implement a student or teacher wrapper for mugs
396
+ --backbone: the backbone of the student/teacher, e.g. ViT-small
397
+ --instance_head: head, including projection/prediction heads, for instance supervision
398
+ --local_group_head: head, including projection/prediction heads, for local group supervision
399
+ --group_head: projection head for group supervision
400
+ """
401
+
402
+ def __init__(self, backbone, instance_head, local_group_head, group_head):
403
+ super(Mugs_Wrapper, self).__init__()
404
+ backbone.fc, backbone.head = nn.Identity(), nn.Identity()
405
+ self.backbone = backbone
406
+ self.instance_head = instance_head
407
+ self.local_group_head = local_group_head
408
+ self.group_head = group_head
409
+
410
+ def forward(self, x, return_target=False, local_group_memory_inputs=None):
411
+ """
412
+ forward input to get instance/local-group/group targets or predictions
413
+ """
414
+ # convert to list
415
+ if not isinstance(x, list):
416
+ x = [x]
417
+ idx_crops = torch.cumsum(
418
+ torch.unique_consecutive(
419
+ torch.tensor([inp.shape[-1] for inp in x]),
420
+ return_counts=True,
421
+ )[1],
422
+ 0,
423
+ )
424
+
425
+ start_idx = 0
426
+ class_tokens = torch.empty(0).to(x[0].device)
427
+ mean_patch_tokens = torch.empty(0).to(x[0].device)
428
+ memory_class_tokens = torch.empty(0).to(x[0].device)
429
+ for _, end_idx in enumerate(idx_crops):
430
+ input = torch.cat(x[start_idx:end_idx])
431
+ token_feat, memory_class_token_feat = self.backbone(
432
+ input,
433
+ return_all=True,
434
+ local_group_memory_inputs=local_group_memory_inputs,
435
+ ) # [[16, 197, 384], [16, 384]] teacher
436
+ # [[16, 197, 384], [16, 384]] student [[48, 37, 384], [48, 384]]
437
+
438
+ class_token_feat = token_feat[
439
+ :, 0
440
+ ] # class tokens in ViT, [16, 384] teacher [16, 384] student [48, 384]
441
+ class_tokens = torch.cat((class_tokens, class_token_feat))
442
+
443
+ start_idx = end_idx
444
+
445
+ if self.local_group_head is not None:
446
+ memory_class_tokens = torch.cat(
447
+ (memory_class_tokens, memory_class_token_feat)
448
+ )
449
+ if input.shape[-1] == 224:
450
+ mean_patch_tokens = torch.cat(
451
+ (mean_patch_tokens, token_feat[:, 1:].mean(dim=1))
452
+ )
453
+
454
+ ## target [16, 256] for teacher, [64, 256] for student,
455
+ instance_feat = (
456
+ self.instance_head(class_tokens, return_target)
457
+ if self.instance_head is not None
458
+ else None
459
+ )
460
+
461
+ ## target [16, 256] for teacher, [64, 256] for student
462
+ local_group_feat = (
463
+ self.local_group_head(memory_class_tokens, return_target)
464
+ if self.local_group_head is not None
465
+ else None
466
+ )
467
+
468
+ # target [16, 65536] for teacher, [64, 65536] for student
469
+ group_feat = (
470
+ self.group_head(class_tokens) if self.group_head is not None else None
471
+ )
472
+ return instance_feat, local_group_feat, group_feat, mean_patch_tokens.detach()
473
+
474
+
475
+ def get_model(args):
476
+ """
477
+ build a student or teacher for mugs, including backbone, instance/local-group/group heads,
478
+ and memory buffer
479
+ """
480
+ ## backbone
481
+ if args.arch in vits.__dict__.keys():
482
+ student = vits.__dict__[args.arch](
483
+ patch_size=args.patch_size,
484
+ num_relation_blocks=1,
485
+ drop_path_rate=args.drop_path_rate, # stochastic depth
486
+ )
487
+ teacher = vits.__dict__[args.arch](
488
+ patch_size=args.patch_size, num_relation_blocks=1
489
+ )
490
+ embed_dim = student.embed_dim
491
+ else:
492
+ assert f"Unknow architecture: {args.arch}"
493
+
494
+ ## memory buffer for local-group loss
495
+ student_mem = vit_mem(
496
+ embed_dim, K=args.local_group_queue_size, top_n=args.local_group_knn_top_n
497
+ )
498
+ teacher_mem = vit_mem(
499
+ embed_dim, K=args.local_group_queue_size, top_n=args.local_group_knn_top_n
500
+ )
501
+
502
+ ## multi-crop wrapper handles forward with inputs of different resolutions
503
+ student_instance_head, student_local_group_head, student_group_head = (
504
+ None,
505
+ None,
506
+ None,
507
+ )
508
+ teacher_instance_head, teacher_local_group_head, teacher_group_head = (
509
+ None,
510
+ None,
511
+ None,
512
+ )
513
+
514
+ # instance head
515
+ if args.loss_weights[0] > 0:
516
+ student_instance_head = Instance_Superivsion_Head(
517
+ in_dim=embed_dim,
518
+ hidden_dim=2048,
519
+ out_dim=args.instance_out_dim,
520
+ pred_hidden_dim=4096,
521
+ nlayers=3,
522
+ proj_bn=args.use_bn_in_head,
523
+ pred_bn=False,
524
+ norm_before_pred=args.norm_before_pred,
525
+ )
526
+ teacher_instance_head = Instance_Superivsion_Head(
527
+ in_dim=embed_dim,
528
+ hidden_dim=2048,
529
+ out_dim=args.instance_out_dim,
530
+ pred_hidden_dim=0,
531
+ nlayers=3,
532
+ proj_bn=args.use_bn_in_head,
533
+ pred_bn=False,
534
+ norm_before_pred=args.norm_before_pred,
535
+ )
536
+
537
+ # local group head
538
+ if args.loss_weights[1] > 0:
539
+ student_local_group_head = Local_Group_Superivsion_Head(
540
+ in_dim=embed_dim,
541
+ hidden_dim=2048,
542
+ out_dim=args.local_group_out_dim,
543
+ pred_hidden_dim=4096,
544
+ nlayers=3,
545
+ proj_bn=args.use_bn_in_head,
546
+ pred_bn=False,
547
+ norm_before_pred=args.norm_before_pred,
548
+ )
549
+ teacher_local_group_head = Local_Group_Superivsion_Head(
550
+ in_dim=embed_dim,
551
+ hidden_dim=2048,
552
+ out_dim=args.local_group_out_dim,
553
+ pred_hidden_dim=0,
554
+ nlayers=3,
555
+ proj_bn=args.use_bn_in_head,
556
+ pred_bn=False,
557
+ norm_before_pred=args.norm_before_pred,
558
+ )
559
+
560
+ # group head
561
+ if args.loss_weights[2] > 0:
562
+ student_group_head = Group_Superivsion_Head(
563
+ in_dim=embed_dim,
564
+ out_dim=args.group_out_dim,
565
+ hidden_dim=2048,
566
+ bottleneck_dim=args.group_bottleneck_dim,
567
+ nlayers=3,
568
+ use_bn=args.use_bn_in_head,
569
+ norm_last_layer=args.norm_last_layer,
570
+ )
571
+ teacher_group_head = Group_Superivsion_Head(
572
+ in_dim=embed_dim,
573
+ out_dim=args.group_out_dim,
574
+ hidden_dim=2048,
575
+ bottleneck_dim=args.group_bottleneck_dim,
576
+ nlayers=3,
577
+ use_bn=args.use_bn_in_head,
578
+ norm_last_layer=args.norm_last_layer,
579
+ )
580
+
581
+ # multi-crop wrapper
582
+ student = Mugs_Wrapper(
583
+ student, student_instance_head, student_local_group_head, student_group_head
584
+ )
585
+
586
+ teacher = Mugs_Wrapper(
587
+ teacher, teacher_instance_head, teacher_local_group_head, teacher_group_head
588
+ )
589
+
590
+ return student, teacher, student_mem, teacher_mem
591
+
592
+
593
+ # utils
594
+ @torch.no_grad()
595
+ def concat_all_gather(tensor):
596
+ """
597
+ Performs all_gather operation on the provided tensors.
598
+ *** Warning ***: torch.distributed.all_gather has no gradient.
599
+ """
600
+
601
+ tensors_gather = [
602
+ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
603
+ ]
604
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
605
+
606
+ output = torch.cat(tensors_gather, dim=0)
607
+ return output
src/multicropdataset.py ADDED
@@ -0,0 +1,445 @@
1
+ # Copyright 2022 Garena Online Private Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ multi-crop dataset: implements the multi-crop augmentation and the pretraining dataset/dataloader
16
+ """
17
+ import copy
18
+ import random
19
+
20
+ import torch
21
+ import torchvision.transforms as transforms
22
+ from PIL import Image, ImageFilter, ImageOps
23
+ from src.dataset import ImageFolder
24
+ from src.RandAugment import rand_augment_transform
25
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
26
+ from timm.data.random_erasing import RandomErasing
27
+ from timm.data.transforms import _pil_interp
28
+
29
+
30
+ class GaussianBlur(object):
31
+ """
32
+ Apply Gaussian Blur to the PIL image.
33
+ """
34
+
35
+ def __init__(self, p=0.5, radius_min=0.1, radius_max=2.0):
36
+ self.prob = p
37
+ self.radius_min = radius_min
38
+ self.radius_max = radius_max
39
+
40
+ def __call__(self, img):
41
+ do_it = random.random() <= self.prob
42
+ if not do_it:
43
+ return img
44
+
45
+ return img.filter(
46
+ ImageFilter.GaussianBlur(
47
+ radius=random.uniform(self.radius_min, self.radius_max)
48
+ )
49
+ )
50
+
51
+
52
+ class Solarization(object):
53
+ """
54
+ Apply Solarization to the PIL image.
55
+ """
56
+
57
+ def __init__(self, p):
58
+ self.p = p
59
+
60
+ def __call__(self, img):
61
+ if random.random() < self.p:
62
+ return ImageOps.solarize(img)
63
+ else:
64
+ return img
65
+
66
+
67
+ def strong_transforms(
68
+ img_size=224,
69
+ scale=(0.08, 1.0),
70
+ ratio=(0.75, 1.3333333333333333),
71
+ hflip=0.5,
72
+ vflip=0.0,
73
+ color_jitter=0.4,
74
+ auto_augment="rand-m9-mstd0.5-inc1",
75
+ interpolation="random",
76
+ use_prefetcher=True,
77
+ mean=IMAGENET_DEFAULT_MEAN, # (0.485, 0.456, 0.406)
78
+ std=IMAGENET_DEFAULT_STD, # (0.229, 0.224, 0.225)
79
+ re_prob=0.25,
80
+ re_mode="pixel",
81
+ re_count=1,
82
+ re_num_splits=0,
83
+ color_aug=False,
84
+ strong_ratio=0.45,
85
+ ):
86
+ """
87
+ for use in a mixing dataset that passes
88
+ * all data through the first (primary) transform, called the 'clean' data
89
+ * a portion of the data through the secondary transform
90
+ * normalizes and converts the branches above with the third, final transform
91
+ """
92
+
93
+ scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
94
+ ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range
95
+
96
+ primary_tfl = []
97
+ if hflip > 0.0:
98
+ primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
99
+ if vflip > 0.0:
100
+ primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
101
+
102
+ secondary_tfl = []
103
+ if auto_augment:
104
+ assert isinstance(auto_augment, str)
105
+ if isinstance(img_size, tuple):
106
+ img_size_min = min(img_size)
107
+ else:
108
+ img_size_min = img_size
109
+ aa_params = dict(
110
+ translate_const=int(img_size_min * strong_ratio),
111
+ img_mean=tuple([min(255, round(255 * x)) for x in mean]),
112
+ )
113
+ if interpolation and interpolation != "random":
114
+ aa_params["interpolation"] = _pil_interp(interpolation)
115
+ if auto_augment.startswith("rand"):
116
+ secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
117
+ if color_jitter is not None and color_aug:
118
+ # color jitter is enabled when not using AA
119
+ flip_and_color_jitter = [
120
+ transforms.RandomApply(
121
+ [
122
+ transforms.ColorJitter(
123
+ brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
124
+ )
125
+ ],
126
+ p=0.8,
127
+ ),
128
+ transforms.RandomGrayscale(p=0.2),
129
+ ]
130
+ secondary_tfl += flip_and_color_jitter
131
+
132
+ if interpolation == "random":
133
+ interpolation = (Image.BILINEAR, Image.BICUBIC)
134
+ else:
135
+ interpolation = _pil_interp(interpolation)
136
+ final_tfl = [
137
+ transforms.RandomResizedCrop(
138
+ size=img_size, scale=scale, ratio=ratio, interpolation=Image.BICUBIC
139
+ )
140
+ ]
141
+ if use_prefetcher:
142
+ # prefetcher and collate will handle tensor conversion and norm
143
+ final_tfl += [transforms.ToTensor()]
144
+ else:
145
+ final_tfl += [
146
+ transforms.ToTensor(),
147
+ transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
148
+ ]
149
+ if re_prob > 0.0:
150
+ final_tfl.append(
151
+ RandomErasing(
152
+ re_prob,
153
+ mode=re_mode,
154
+ max_count=re_count,
155
+ num_splits=re_num_splits,
156
+ device="cpu",
157
+ )
158
+ )
159
+ return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
160
+
161
+
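A hedged sketch of calling strong_transforms directly, with prefetching disabled so that normalization happens inside the transform (sizes are illustrative):

aug = strong_transforms(img_size=224, use_prefetcher=False)
out = aug(Image.new("RGB", (256, 256)))  # normalized (and possibly random-erased) tensor of shape [3, 224, 224]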
162
+ class DataAugmentation(object):
163
+ """
164
+ implement multi-crop data augmentation.
165
+ --global_crops_scale: scale range of the 224-sized cropped image before resizing
166
+ --local_crops_scale: scale range of the 96-sized cropped image before resizing
167
+ --local_crops_number: Number of small local views to generate
168
+ --prob: when we use strong augmentation and weak augmentation, the ratio of images to
169
+ be cropped with strong augmentation
170
+ --vanilla_weak_augmentation: whether we use the same augmentation in DINO, namely
171
+ only using weak augmentation
172
+ --color_aug: after AutoAugment, whether we further perform color augmentation
173
+ --local_crop_size: the small crop size
174
+ --timm_auto_augment_par: the parameters for the AutoAugment used in DeiT
175
+ --strong_ratio: the ratio of image augmentation for the AutoAugment used in DeiT
176
+ --re_prob: the re-prob parameter of image augmentation for the AutoAugment used in DeiT
177
+ --use_prefetcher: whether we use prefetcher which can accerelate the training speed
178
+ """
179
+
180
+ def __init__(
181
+ self,
182
+ global_crops_scale,
183
+ local_crops_scale,
184
+ local_crops_number,
185
+ prob=0.5,
186
+ vanilla_weak_augmentation=False,
187
+ color_aug=False,
188
+ local_crop_size=[96],
189
+ timm_auto_augment_par="rand-m9-mstd0.5-inc1",
190
+ strong_ratio=0.45,
191
+ re_prob=0.25,
192
+ use_prefetcher=False,
193
+ ):
194
+
195
+ ## propability to perform strong augmentation
196
+ self.prob = prob
197
+ ## whether we use the commonly used augmentations, e.g. DINO or MoCo-V3
198
+ self.vanilla_weak_augmentation = vanilla_weak_augmentation
199
+
200
+ flip_and_color_jitter = transforms.Compose(
201
+ [
202
+ transforms.RandomHorizontalFlip(p=0.5),
203
+ transforms.RandomApply(
204
+ [
205
+ transforms.ColorJitter(
206
+ brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
207
+ )
208
+ ],
209
+ p=0.8,
210
+ ),
211
+ transforms.RandomGrayscale(p=0.2),
212
+ ]
213
+ )
214
+
215
+ if use_prefetcher:
216
+ normalize = transforms.Compose(
217
+ [
218
+ transforms.ToTensor(),
219
+ ]
220
+ )
221
+ else:
222
+ normalize = transforms.Compose(
223
+ [
224
+ transforms.ToTensor(),
225
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
226
+ ]
227
+ )
228
+
229
+ ##====== build augmentation of global crops, i.e. 224-sized image crops =========
230
+ # first global crop, always weak augmentation
231
+ self.global_transfo1 = transforms.Compose(
232
+ [
233
+ transforms.RandomResizedCrop(
234
+ 224, scale=global_crops_scale, interpolation=Image.BICUBIC
235
+ ),
236
+ flip_and_color_jitter,
237
+ GaussianBlur(1.0),
238
+ normalize,
239
+ ]
240
+ )
241
+
242
+ # second global crop, always weak augmentation
243
+ self.global_transfo2 = transforms.Compose(
244
+ [
245
+ transforms.RandomResizedCrop(
246
+ 224, scale=global_crops_scale, interpolation=Image.BICUBIC
247
+ ),
248
+ flip_and_color_jitter,
249
+ GaussianBlur(0.1),
250
+ Solarization(0.2),
251
+ normalize,
252
+ ]
253
+ )
254
+
255
+ # strong augmentation, maybe used if we need to perform strong augmentation
256
+ self.global_transfo3 = strong_transforms(
257
+ img_size=224,
258
+ scale=global_crops_scale,
259
+ ratio=(0.75, 1.3333333333333333),
260
+ hflip=0.5,
261
+ vflip=0.0,
262
+ color_jitter=0.4,
263
+ auto_augment=timm_auto_augment_par, # 'rand-m9-mstd0.5-inc1'
264
+ interpolation="random",
265
+ use_prefetcher=use_prefetcher, # True
266
+ mean=IMAGENET_DEFAULT_MEAN, # (0.485, 0.456, 0.406)
267
+ std=IMAGENET_DEFAULT_STD, # (0.229, 0.224, 0.225)
268
+ re_prob=re_prob, # 0.25
269
+ re_mode="pixel",
270
+ re_count=1,
271
+ re_num_splits=0,
272
+ color_aug=color_aug,
273
+ strong_ratio=strong_ratio,
274
+ )
275
+
276
+ ##====== build augmentation of local crops, i.e. 96-sized image crops =========
277
+ self.local_crops_number = (
278
+ local_crops_number # transformation for the local small crops
279
+ )
280
+ assert local_crop_size[0] == 96
281
+ # weak augmentation, maybe used if we need to perform weak augmentation
282
+ self.local_transfo = transforms.Compose(
283
+ [
284
+ transforms.RandomResizedCrop(
285
+ local_crop_size[0],
286
+ scale=local_crops_scale,
287
+ interpolation=Image.BICUBIC,
288
+ ),
289
+ flip_and_color_jitter,
290
+ GaussianBlur(p=0.5),
291
+ normalize,
292
+ ]
293
+ )
294
+ # strong augmentation, maybe used if we need to perform strong augmentation
295
+ self.local_transfo2 = strong_transforms(
296
+ img_size=local_crop_size[0], # (224, 224)
297
+ scale=local_crops_scale, # (0.08, 1.0)
298
+ ratio=(0.75, 1.3333333333333333), # (0.75, 1.3333333333333333)
299
+ hflip=0.5, # 0.5
300
+ vflip=0.0, # 0.0
301
+ color_jitter=0.4, # 0.4
302
+ auto_augment=timm_auto_augment_par, # 'rand-m9-mstd0.5-inc1'
303
+ interpolation="random", # 'random'
304
+ use_prefetcher=use_prefetcher, # True
305
+ mean=IMAGENET_DEFAULT_MEAN, # (0.485, 0.456, 0.406)
306
+ std=IMAGENET_DEFAULT_STD, # (0.229, 0.224, 0.225)
307
+ re_prob=re_prob, # 0.25
308
+ re_mode="pixel", # 'pixel'
309
+ re_count=1, # 1
310
+ re_num_splits=0, # 0
311
+ color_aug=color_aug,
312
+ strong_ratio=strong_ratio,
313
+ )
314
+
315
+ def __call__(self, image):
316
+ """
317
+ implement multi-crop data augmentation. Generate two 224-sized +
318
+ "local_crops_number" 96-sized images
319
+ """
320
+ crops = []
321
+ ##====== images to be fed into teacher, two 224-sized =========
322
+ img1 = self.global_transfo1(image)
323
+ img2 = self.global_transfo2(image)
324
+ crops.append(img1)
325
+ crops.append(img2)
326
+
327
+ ##====== images to be fed into student, two 224-sized + "local_crops_number" 96-sized =========
328
+ # first to generate two 224-sized
329
+ # this weak_flag indicates whether the current image is weakly augmented.
330
+ # For local group supervision, we only use weakly augmented images of size 224 to
331
+ # update the memory for local-group aggregation.
332
+ weak_flag = False
333
+
334
+ if self.vanilla_weak_augmentation is True:
335
+ ## directly copy the images of weak augmentation
336
+ crops.append(copy.deepcopy(img1))
337
+ crops.append(copy.deepcopy(img2))
338
+ weak_flag = True
339
+ elif self.prob < 1.0 and random.random() > self.prob:
340
+ ## whether perform strong augmentation
341
+ crops.append(self.global_transfo3(image))
342
+ crops.append(self.global_transfo3(image))
343
+ else:
344
+ ## perform weak augmentation
345
+ crops.append(self.global_transfo1(image))
346
+ crops.append(self.global_transfo2(image))
347
+ weak_flag = True
348
+
349
+ # then to generate "local_crops_number" 96-sized
350
+ for _ in range(self.local_crops_number):
351
+ if self.prob < 1.0 and random.random() > self.prob:
352
+ ## whether perform strong augmentation
353
+ crops.append(self.local_transfo2(image))
354
+ else:
355
+ ## perform weak augmentation
356
+ crops.append(self.local_transfo(image))
357
+
358
+ return crops, weak_flag
359
+
360
+
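A minimal sketch of exercising the multi-crop augmentation above (the crop scales below are illustrative; the real values come from the pretraining arguments):

aug = DataAugmentation(
    global_crops_scale=(0.25, 1.0), local_crops_scale=(0.05, 0.25), local_crops_number=10,
)
crops, weak = aug(Image.new("RGB", (256, 256)))
assert len(crops) == 2 + 2 + 10  # 2 teacher views + 2 student global views + 10 local views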
361
+ def get_dataset(args):
362
+ """
363
+ build a multi-crop data augmentation and a dataset/dataloader
364
+ """
365
+ ## preparing augmentations, including weak and strong augmentations
366
+ transform = DataAugmentation(
367
+ global_crops_scale=args.global_crops_scale,
368
+ local_crops_scale=args.local_crops_scale,
369
+ local_crops_number=args.local_crops_number,
370
+ vanilla_weak_augmentation=args.vanilla_weak_augmentation,
371
+ prob=args.prob,
372
+ color_aug=args.color_aug,
373
+ local_crop_size=args.size_crops,
374
+ timm_auto_augment_par=args.timm_auto_augment_par,
375
+ strong_ratio=args.strong_ratio,
376
+ re_prob=args.re_prob,
377
+ use_prefetcher=args.use_prefetcher,
378
+ )
379
+
380
+ ## For debug mode, we only load the first two classes to reduce data reading time.
381
+ ## otherwise, we load all training data for pretraining.
382
+ class_num = 2 if args.debug else 1000
383
+ dataset = ImageFolder(args.data_path, transform=transform, class_num=class_num)
384
+
385
+ sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
386
+ data_loader = torch.utils.data.DataLoader(
387
+ dataset,
388
+ sampler=sampler,
389
+ batch_size=args.batch_size_per_gpu,
390
+ num_workers=args.num_workers,
391
+ pin_memory=True,
392
+ drop_last=True,
393
+ )
394
+ return data_loader
395
+
396
+
397
+ class data_prefetcher:
398
+ """
399
+ implement a data prefetcher. We perform some augmentations on GPUs instead of CPUs
400
+ --loader: a data loader
401
+ --fp16: whether we use fp16; if yes, we need to transform the data to fp16
402
+ """
403
+
404
+ def __init__(self, loader, fp16=True):
405
+ self.loader = iter(loader)
406
+ self.fp16 = fp16
407
+ self.stream = torch.cuda.Stream()
408
+ self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1, 3, 1, 1)
409
+ self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1, 3, 1, 1)
410
+ if fp16:
411
+ self.mean = self.mean.half()
412
+ self.std = self.std.half()
413
+
414
+ self.preload()
415
+
416
+ def preload(self):
417
+ """
418
+ preload the next minibatch of data
419
+ """
420
+ try:
421
+ self.multi_crops, self.weak_flag = next(self.loader)
422
+ except StopIteration:
423
+ self.multi_crops, self.weak_flag = None, None
424
+ return
425
+
426
+ with torch.cuda.stream(self.stream):
427
+ for i in range(len(self.multi_crops)):
428
+ self.multi_crops[i] = self.multi_crops[i].cuda(non_blocking=True)
429
+ if self.fp16:
430
+ self.multi_crops[i] = (
431
+ self.multi_crops[i].half().sub_(self.mean).div_(self.std)
432
+ )
433
+ else:
434
+ self.multi_crops[i] = (
435
+ self.multi_crops[i].float().sub_(self.mean).div_(self.std)
436
+ )
437
+
438
+ def next(self):
439
+ """
440
+ load the next minibatch of data
441
+ """
442
+ torch.cuda.current_stream().wait_stream(self.stream)
443
+ multi_crops, weak_flags = self.multi_crops, self.weak_flag
444
+ self.preload()
445
+ return multi_crops, weak_flags
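A hedged sketch of driving the prefetcher above inside a training loop (it requires a CUDA device, so it is shown as comments):

# prefetcher = data_prefetcher(data_loader, fp16=True)
# crops, weak_flags = prefetcher.next()
# while crops is not None:
#     ...  # forward/backward on the list of normalized crops
#     crops, weak_flags = prefetcher.next()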
src/optimizer.py ADDED
@@ -0,0 +1,210 @@
1
+ # Copyright 2022 Garena Online Private Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ implement some functions for optimizers
16
+ """
17
+ import numpy as np
18
+ import torch
19
+
20
+ import utils
21
+
22
+
23
+ def clip_gradients(model, clip):
24
+ """
25
+ clip gradient if gradient norm > clip
26
+ """
27
+ norms = []
28
+ for name, p in model.named_parameters():
29
+ if p.grad is not None:
30
+ param_norm = p.grad.data.norm(2)
31
+ norms.append(param_norm.item())
32
+ clip_coef = clip / (param_norm + 1e-6)
33
+ if clip_coef < 1:
34
+ p.grad.data.mul_(clip_coef)
35
+ return norms
36
+
37
+
38
+ def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
39
+ """
40
+ cancel the gradients of the last layer while epoch < freeze_last_layer
41
+ """
42
+ if epoch >= freeze_last_layer:
43
+ return
44
+ for n, p in model.named_parameters():
45
+ if "last_layer" in n:
46
+ p.grad = None
47
+
48
+
49
+ def cosine_scheduler(
50
+ base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0
51
+ ):
52
+ """
53
+ linear warmup from start_warmup_value to base_value over the first warmup_epochs epochs;
+ then cosine decay from base_value to final_value over the remaining epochs - warmup_epochs epochs
55
+ """
56
+ warmup_schedule = np.array([])
57
+ warmup_iters = warmup_epochs * niter_per_ep
58
+ if warmup_epochs > 0:
59
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
60
+
61
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
62
+ schedule = final_value + 0.5 * (base_value - final_value) * (
63
+ 1 + np.cos(np.pi * iters / len(iters))
64
+ )
65
+
66
+ schedule = np.concatenate((warmup_schedule, schedule))
67
+ assert len(schedule) == epochs * niter_per_ep
68
+ return schedule
69
+
70
+
71
+ def get_params_groups(model):
72
+ """
73
+ divide the parameters into several groups, see below
74
+ """
75
+ regularized = []
76
+ not_regularized = []
77
+ patch_embed = []
78
+ patch_embed_not_regularized = []
79
+ for name, param in model.named_parameters():
80
+ if not param.requires_grad:
81
+ continue
82
+ # we do not regularize biases nor Norm parameters
83
+ if name.endswith(".bias") or len(param.shape) == 1:
84
+ if "patch_embed" in name:
85
+ patch_embed_not_regularized.append(param)
86
+ else:
87
+ not_regularized.append(param)
88
+ elif "patch_embed" in name:
89
+ patch_embed.append(param)
90
+ else:
91
+ regularized.append(param)
92
+ return [
93
+ {"name": "normal_params", "params": regularized},
94
+ {"name": "patch_embed", "params": patch_embed},
95
+ {
96
+ "name": "no_wd",
97
+ "params": not_regularized,
98
+ "apply_wd": False,
99
+ "weight_decay": 0.0,
100
+ },
101
+ {
102
+ "name": "patch_embed_no_wd",
103
+ "params": patch_embed_not_regularized,
104
+ "apply_wd": False,
105
+ "weight_decay": 0.0,
106
+ },
107
+ ]
108
+
109
+
110
+ class LARS(torch.optim.Optimizer):
111
+ """
112
+ Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
113
+ """
114
+
115
+ def __init__(
116
+ self,
117
+ params,
118
+ lr=0,
119
+ weight_decay=0,
120
+ momentum=0.9,
121
+ eta=0.001,
122
+ weight_decay_filter=None,
123
+ lars_adaptation_filter=None,
124
+ ):
125
+ defaults = dict(
126
+ lr=lr,
127
+ weight_decay=weight_decay,
128
+ momentum=momentum,
129
+ eta=eta,
130
+ weight_decay_filter=weight_decay_filter,
131
+ lars_adaptation_filter=lars_adaptation_filter,
132
+ )
133
+ super().__init__(params, defaults)
134
+
135
+ @torch.no_grad()
136
+ def step(self):
137
+ for g in self.param_groups:
138
+ for p in g["params"]:
139
+ dp = p.grad
140
+
141
+ if dp is None:
142
+ continue
143
+
144
+ if p.ndim != 1:
145
+ dp = dp.add(p, alpha=g["weight_decay"])
146
+
147
+ if p.ndim != 1:
148
+ param_norm = torch.norm(p)
149
+ update_norm = torch.norm(dp)
150
+ one = torch.ones_like(param_norm)
151
+ q = torch.where(
152
+ param_norm > 0.0,
153
+ torch.where(
154
+ update_norm > 0, (g["eta"] * param_norm / update_norm), one
155
+ ),
156
+ one,
157
+ )
158
+ dp = dp.mul(q)
159
+
160
+ param_state = self.state[p]
161
+ if "mu" not in param_state:
162
+ param_state["mu"] = torch.zeros_like(p)
163
+ mu = param_state["mu"]
164
+ mu.mul_(g["momentum"]).add_(dp)
165
+
166
+ p.add_(mu, alpha=-g["lr"])
167
+
168
+
169
+ def get_optimizer(student, len_dataloader, args):
170
+ """
171
+ build an optimizer for training
172
+ """
173
+ # ============ preparing optimizer ... ============
174
+ params_groups = get_params_groups(student)
175
+ if args.optimizer == "adamw":
176
+ optimizer = torch.optim.AdamW(params_groups) # to use with ViTs
177
+ elif args.optimizer == "sgd":
178
+ optimizer = torch.optim.SGD(
179
+ params_groups, lr=0, momentum=0.9
180
+ ) # lr is set by scheduler
181
+ elif args.optimizer == "lars":
182
+ optimizer = LARS(params_groups) # to use with convnet and large batches
183
+ # for mixed precision training
184
+ fp16_scaler = None
185
+ if args.use_fp16:
186
+ fp16_scaler = torch.cuda.amp.GradScaler()
187
+
188
+ # ============ init schedulers ... ============
189
+ lr_schedule = cosine_scheduler(
190
+ args.lr
191
+ * (args.batch_size_per_gpu * utils.get_world_size())
192
+ / 256.0, # linear scaling rule
193
+ args.min_lr,
194
+ args.epochs,
195
+ len_dataloader,
196
+ warmup_epochs=args.warmup_epochs,
197
+ )
198
+ wd_schedule = cosine_scheduler(
199
+ args.weight_decay,
200
+ args.weight_decay_end,
201
+ args.epochs,
202
+ len_dataloader, # len(data_loader),
203
+ )
204
+ # momentum parameter is increased to 1. during training with a cosine schedule
205
+ momentum_schedule = cosine_scheduler(
206
+ args.momentum_teacher, 1, args.epochs, len_dataloader
207
+ )
208
+ print("Loss, optimizer and schedulers ready.")
209
+
210
+ return optimizer, fp16_scaler, lr_schedule, wd_schedule, momentum_schedule
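A hedged sketch of how the returned schedules could be applied at each iteration (mirroring the usual DINO-style loop; variable names are placeholders and the actual loop lives in main.py):

# it = epoch * len_dataloader + step
# for group in optimizer.param_groups:
#     group["lr"] = lr_schedule[it]
#     if group.get("apply_wd", True):  # groups flagged with apply_wd=False keep weight_decay = 0
#         group["weight_decay"] = wd_schedule[it]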
src/vision_transformer.py ADDED
@@ -0,0 +1,491 @@
1
+ # Copyright 2022 Garena Online Private Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ ViT backbones, including ViT-small, ViT-base, ViT-large
16
+ Mostly copy-paste from timm library.
17
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
18
+ """
19
+ import math
20
+ from functools import partial
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+
26
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
27
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
28
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
29
+ def norm_cdf(x):
30
+ # Computes standard normal cumulative distribution function
31
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
32
+
33
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
34
+ warnings.warn(
35
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
36
+ "The distribution of values may be incorrect.",
37
+ stacklevel=2,
38
+ )
39
+
40
+ with torch.no_grad():
41
+ # Values are generated by using a truncated uniform distribution and
42
+ # then using the inverse CDF for the normal distribution.
43
+ # Get upper and lower cdf values
44
+ lower = norm_cdf((a - mean) / std)
45
+ upper = norm_cdf((b - mean) / std)
46
+
47
+ # Uniformly fill tensor with values from [l, u], then translate to
48
+ # [2l-1, 2u-1].
49
+ tensor.uniform_(2 * lower - 1, 2 * upper - 1)
50
+
51
+ # Use inverse cdf transform for normal distribution to get truncated
52
+ # standard normal
53
+ tensor.erfinv_()
54
+
55
+ # Transform to proper mean, std
56
+ tensor.mul_(std * math.sqrt(2.0))
57
+ tensor.add_(mean)
58
+
59
+ # Clamp to ensure it's in the proper range
60
+ tensor.clamp_(min=a, max=b)
61
+ return tensor
62
+
63
+
64
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
65
+ # type: (torch.tensor, float, float, float, float) -> torch.tensor
66
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
67
+
68
+
69
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
70
+ """
71
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
72
+ """
73
+ if drop_prob == 0.0 or not training:
74
+ return x
75
+ keep_prob = 1 - drop_prob
76
+ shape = (x.shape[0],) + (1,) * (
77
+ x.ndim - 1
78
+ ) # work with diff dim tensors, not just 2D ConvNets
79
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
80
+ random_tensor.floor_() # binarize
81
+ output = x.div(keep_prob) * random_tensor
82
+ return output
83
+
84
+
85
+ class DropPath(nn.Module):
86
+ """
87
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
88
+ """
89
+
90
+ def __init__(self, drop_prob=None):
91
+ super(DropPath, self).__init__()
92
+ self.drop_prob = drop_prob
93
+
94
+ def forward(self, x):
95
+ return drop_path(x, self.drop_prob, self.training)
96
+
97
+
98
+ class Mlp(nn.Module):
99
+ """
100
+ MLP module in ViT
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ in_features,
106
+ hidden_features=None,
107
+ out_features=None,
108
+ act_layer=nn.GELU,
109
+ drop=0.0,
110
+ ):
111
+ super().__init__()
112
+ out_features = out_features or in_features
113
+ hidden_features = hidden_features or in_features
114
+ self.fc1 = nn.Linear(in_features, hidden_features)
115
+ self.act = act_layer()
116
+ self.fc2 = nn.Linear(hidden_features, out_features)
117
+ self.drop = nn.Dropout(drop)
118
+
119
+ def forward(self, x):
120
+ x = self.fc1(x)
121
+ x = self.act(x)
122
+ x = self.drop(x)
123
+ x = self.fc2(x)
124
+ x = self.drop(x)
125
+ return x
126
+
127
+
128
+ class Attention(nn.Module):
129
+ """
130
+ Attention module in ViT
131
+ """
132
+
133
+ def __init__(
134
+ self,
135
+ dim,
136
+ num_heads=8,
137
+ qkv_bias=False,
138
+ qk_scale=None,
139
+ attn_drop=0.0,
140
+ proj_drop=0.0,
141
+ ):
142
+ super().__init__()
143
+ self.num_heads = num_heads
144
+ head_dim = dim // num_heads
145
+ self.scale = qk_scale or head_dim ** -0.5
146
+
147
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
148
+ self.attn_drop = nn.Dropout(attn_drop)
149
+ self.proj = nn.Linear(dim, dim)
150
+ self.proj_drop = nn.Dropout(proj_drop)
151
+
152
+ def forward(self, x):
153
+ B, N, C = x.shape
154
+
155
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
156
+ reshaped_qkv = qkv.permute(2, 0, 3, 1, 4)
157
+ q, k, v = reshaped_qkv[0], reshaped_qkv[1], reshaped_qkv[2]
158
+
159
+ attn = (q @ k.transpose(-2, -1)) * self.scale
160
+ attn = attn.softmax(dim=-1)
161
+ attn = self.attn_drop(attn)
162
+
163
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
164
+ x = self.proj(x)
165
+ x = self.proj_drop(x)
166
+ return x, attn
167
+
168
+
169
+ class Block(nn.Module):
170
+ """
171
+ ViT block, including Attention, MLP, etc.
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ dim,
177
+ num_heads,
178
+ mlp_ratio=4.0,
179
+ qkv_bias=False,
180
+ qk_scale=None,
181
+ drop=0.0,
182
+ attn_drop=0.0,
183
+ drop_path=0.0,
184
+ act_layer=nn.GELU,
185
+ norm_layer=nn.LayerNorm,
186
+ ):
187
+ super().__init__()
188
+ self.norm1 = norm_layer(dim)
189
+ self.attn = Attention(
190
+ dim,
191
+ num_heads=num_heads,
192
+ qkv_bias=qkv_bias,
193
+ qk_scale=qk_scale,
194
+ attn_drop=attn_drop,
195
+ proj_drop=drop,
196
+ )
197
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
198
+ self.norm2 = norm_layer(dim)
199
+ mlp_hidden_dim = int(dim * mlp_ratio)
200
+ self.mlp = Mlp(
201
+ in_features=dim,
202
+ hidden_features=mlp_hidden_dim,
203
+ act_layer=act_layer,
204
+ drop=drop,
205
+ )
206
+
207
+ def forward(self, x, return_attention=False):
208
+ y, attn = self.attn(self.norm1(x))
209
+ if return_attention:
210
+ return attn
211
+ x = x + self.drop_path(y)
212
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
213
+ return x
214
+
215
+
216
+ class PatchEmbed(nn.Module):
217
+ """Image to Patch Embedding"""
218
+
219
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
220
+ super().__init__()
221
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
222
+ self.img_size = img_size
223
+ self.patch_size = patch_size
224
+ self.num_patches = num_patches
225
+
226
+ self.proj = nn.Conv2d(
227
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
228
+ )
229
+
230
+ def forward(self, x):
231
+ x = self.proj(x).flatten(2).transpose(1, 2)
232
+ return x
233
+
234
+
235
+ class VisionTransformer(nn.Module):
236
+ """Vision Transformer"""
237
+
238
+ def __init__(
239
+ self,
240
+ img_size=[224, 224],
241
+ patch_size=16,
242
+ in_chans=3,
243
+ num_classes=0,
244
+ embed_dim=768,
245
+ depth=12,
246
+ num_heads=12,
247
+ mlp_ratio=4.0,
248
+ qkv_bias=False,
249
+ qk_scale=None,
250
+ drop_rate=0.0,
251
+ attn_drop_rate=0.0,
252
+ drop_path_rate=0.0,
253
+ norm_layer=nn.LayerNorm,
254
+ num_relation_blocks=0,
255
+ **kwargs
256
+ ):
257
+ super().__init__()
258
+ self.num_features = self.embed_dim = embed_dim
259
+ self.patch_size = patch_size
260
+ self.num_classes = num_classes
261
+ self.depth = depth
262
+
263
+ self.patch_embed = PatchEmbed(
264
+ img_size=img_size[0],
265
+ patch_size=patch_size,
266
+ in_chans=in_chans,
267
+ embed_dim=embed_dim,
268
+ )
269
+
270
+ num_patches = self.patch_embed.num_patches
271
+
272
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
273
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
274
+ trunc_normal_(self.pos_embed, std=0.02)
275
+
276
+ self.pos_drop = nn.Dropout(p=drop_rate)
277
+
278
+ dpr = [
279
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
280
+ ] # stochastic depth decay rule
281
+ self.blocks = nn.ModuleList(
282
+ [
283
+ Block(
284
+ dim=embed_dim,
285
+ num_heads=num_heads,
286
+ mlp_ratio=mlp_ratio,
287
+ qkv_bias=qkv_bias,
288
+ qk_scale=qk_scale,
289
+ drop=drop_rate,
290
+ attn_drop=attn_drop_rate,
291
+ drop_path=dpr[i],
292
+ norm_layer=norm_layer,
293
+ )
294
+ for i in range(depth)
295
+ ]
296
+ )
297
+ self.norm = norm_layer(embed_dim)
298
+
299
+ # Classifier head
300
+ self.head = (
301
+ nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
302
+ )
303
+
304
+ self.num_relation_blocks = num_relation_blocks
305
+ if num_relation_blocks > 0:
306
+ self.relation_blocks = nn.ModuleList(
307
+ [
308
+ Block(
309
+ dim=embed_dim,
310
+ num_heads=num_heads,
311
+ mlp_ratio=mlp_ratio,
312
+ qkv_bias=qkv_bias,
313
+ qk_scale=qk_scale,
314
+ drop=drop_rate,
315
+ attn_drop=attn_drop_rate,
316
+ drop_path=dpr[i],
317
+ norm_layer=norm_layer,
318
+ )
319
+ for i in range(int(num_relation_blocks))
320
+ ]
321
+ )
322
+
323
+ trunc_normal_(self.cls_token, std=0.02)
324
+ self.apply(self._init_weights)
325
+
326
+ def add_pos_emb_for_cls_token(self):
327
+ pe_cls_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
328
+ self.pos_embed = nn.Parameter(torch.cat([pe_cls_token, self.pos_embed], dim=1))
329
+ self.pos_embed.requires_grad = False
330
+
331
+ def _init_weights(self, m):
332
+ if isinstance(m, nn.Linear):
333
+ trunc_normal_(m.weight, std=0.02)
334
+ if isinstance(m, nn.Linear) and m.bias is not None:
335
+ nn.init.constant_(m.bias, 0)
336
+ elif isinstance(m, nn.LayerNorm):
337
+ nn.init.constant_(m.bias, 0)
338
+ nn.init.constant_(m.weight, 1.0)
339
+
340
+ def interpolate_pos_encoding(self, x, w, h):
341
+ npatch = x.shape[1] - 1
342
+ N = self.pos_embed.shape[1] - 1
343
+ if npatch == N and w == h:
344
+ return self.pos_embed
345
+ class_pos_embed = self.pos_embed[:, 0]
346
+ patch_pos_embed = self.pos_embed[:, 1:]
347
+ dim = x.shape[-1]
348
+ w0 = w // self.patch_embed.patch_size
349
+ h0 = h // self.patch_embed.patch_size
350
+ # we add a small number to avoid floating point error in the interpolation
351
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
352
+ w0, h0 = w0 + 0.1, h0 + 0.1
353
+ patch_pos_embed = nn.functional.interpolate(
354
+ patch_pos_embed.reshape(
355
+ 1, int(math.sqrt(N)), int(math.sqrt(N)), dim
356
+ ).permute(0, 3, 1, 2),
357
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
358
+ mode="bicubic",
359
+ )
360
+ assert (
361
+ int(w0) == patch_pos_embed.shape[-2]
362
+ and int(h0) == patch_pos_embed.shape[-1]
363
+ )
364
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
365
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
366
+
367
+ def prepare_tokens(self, x):
368
+ B, nc, w, h = x.shape
369
+ x = self.patch_embed(x) # patch linear embedding
370
+
371
+ # add the [CLS] token to the embed patch tokens
372
+ cls_tokens = self.cls_token.expand(B, -1, -1)
373
+ x = torch.cat((cls_tokens, x), dim=1)
374
+
375
+ # add positional encoding to each token
376
+ x = x + self.interpolate_pos_encoding(x, w, h)
377
+ return self.pos_drop(x)
378
+
379
+ def forward(self, x, return_all=False, local_group_memory_inputs=None, **kwargs):
380
+ x = self.prepare_tokens(x)
381
+ for blk in self.blocks:
382
+ x = blk(x)
383
+
384
+ if self.num_relation_blocks > 0:
385
+ mem = local_group_memory_inputs.get("mem")
386
+ if mem is not None:
387
+ m, _ = mem(x.mean(1))
388
+ rx = torch.cat((x.mean(1).unsqueeze(1), m), dim=1)
389
+ else:
390
+ rx = x
391
+ for i, blk in enumerate(self.relation_blocks):
392
+ rx = blk(rx)
393
+ relation_out = self.norm(rx[:, 0])
394
+
395
+ x = self.norm(x)
396
+ if self.num_classes > 0:
397
+ return self.head(x[:, 0])
398
+
399
+ if return_all:
400
+ return x, relation_out
401
+ else:
402
+ return x[:, 0], relation_out
403
+
404
+ def forward_knn(self, x):
405
+ x = self.prepare_tokens(x)
406
+ for blk in self.blocks:
407
+ x = blk(x)
408
+ x = self.norm(x)
409
+ return x[:, 0]
410
+
411
+ def get_last_selfattention(self, x):
412
+ x = self.prepare_tokens(x)
413
+ for i, blk in enumerate(self.blocks):
414
+ if i < len(self.blocks) - 1:
415
+ x = blk(x)
416
+ else:
417
+ # return attention of the last block
418
+ return blk(x, return_attention=True)
419
+
420
+ def get_intermediate_layers(self, x, n=1):
421
+ x = self.prepare_tokens(x)
422
+ # we return the output tokens from the `n` last blocks
423
+ output = []
424
+ for i, blk in enumerate(self.blocks):
425
+ x = blk(x)
426
+ if len(self.blocks) - i <= n:
427
+ output.append(self.norm(x))
428
+ return output
429
+
430
+ def get_num_layers(self):
431
+ return self.depth
432
+
433
+ @torch.jit.ignore
434
+ def no_weight_decay(self):
435
+ return {"pos_embed", "cls_token"}
436
+
437
+
438
+ def vit_tiny(patch_size=16, **kwargs):
439
+ model = VisionTransformer(
440
+ patch_size=patch_size,
441
+ embed_dim=192,
442
+ depth=12,
443
+ num_heads=3,
444
+ mlp_ratio=4,
445
+ qkv_bias=True,
446
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
447
+ **kwargs
448
+ )
449
+ return model
450
+
451
+
452
+ def vit_small(patch_size=16, **kwargs):
453
+ model = VisionTransformer(
454
+ patch_size=patch_size,
455
+ embed_dim=384,
456
+ depth=12,
457
+ num_heads=6,
458
+ mlp_ratio=4,
459
+ qkv_bias=True,
460
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
461
+ **kwargs
462
+ )
463
+ return model
464
+
465
+
466
+ def vit_base(patch_size=16, **kwargs):
467
+ model = VisionTransformer(
468
+ patch_size=patch_size,
469
+ embed_dim=768,
470
+ depth=12,
471
+ num_heads=12,
472
+ mlp_ratio=4,
473
+ qkv_bias=True,
474
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
475
+ **kwargs
476
+ )
477
+ return model
478
+
479
+
480
+ def vit_large(patch_size=16, **kwargs):
481
+ model = VisionTransformer(
482
+ patch_size=patch_size,
483
+ embed_dim=1024,
484
+ depth=24,
485
+ num_heads=16,
486
+ mlp_ratio=4,
487
+ qkv_bias=True,
488
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
489
+ **kwargs
490
+ )
491
+ return model
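A minimal sketch of instantiating one of the backbones above and running a forward pass (hypothetical; the relation block and the empty memory input mirror how get_model in src/model.py builds the backbone):

model = vit_small(patch_size=16, num_relation_blocks=1)
cls_tok, relation = model(torch.randn(2, 3, 224, 224), local_group_memory_inputs={"mem": None})
# cls_tok: [2, 384] class tokens; relation: [2, 384] relation-block output used for local-group supervision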
utils.py ADDED
@@ -0,0 +1,583 @@
1
+ # Copyright 2022 Garena Online Private Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Misc functions.
16
+
17
+ Mostly copy-paste from torchvision references or other public repos like DETR and DINO:
18
+ https://github.com/facebookresearch/detr/blob/master/util/misc.py
19
+ https://github.com/facebookresearch/dino/blob/main/utils.py
20
+ """
21
+ import argparse
+ import datetime
22
+ import logging
23
+ import os
24
+ import subprocess
25
+ import sys
26
+ import time
27
+ from collections import defaultdict, deque
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torch.distributed as dist
32
+ from torch import nn
33
+
34
+
35
+ def get_logger(file_path_name):
36
+ """
37
+ build a logger which writes both to disk and to the terminal
38
+ """
39
+ logger = logging.getLogger()
40
+ logger.setLevel("INFO")
41
+ BASIC_FORMAT = "%(levelname)s:%(message)s"
42
+ DATE_FORMAT = ""
43
+ formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT)
44
+ chlr = logging.StreamHandler()
45
+ chlr.setFormatter(formatter)
46
+ chlr.setLevel("INFO")
47
+ fhlr = logging.FileHandler(file_path_name)
48
+ fhlr.setFormatter(formatter)
49
+ logger.addHandler(chlr)
50
+ logger.addHandler(fhlr)
51
+
52
+ return logger
53
+
54
+
55
+ def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
56
+ """
57
+ Re-start from checkpoint
58
+ """
59
+ if not os.path.isfile(ckp_path):
60
+ return
61
+ print("Found checkpoint at {}".format(ckp_path))
62
+
63
+ # open checkpoint file
64
+ checkpoint = torch.load(ckp_path, map_location="cpu")
65
+ # key is what to look for in the checkpoint file
66
+ # value is the object to load
67
+ # example: {'state_dict': model}
68
+ for key, value in kwargs.items():
69
+ if key in checkpoint and value is not None:
70
+ try:
71
+ msg = value.load_state_dict(checkpoint[key], strict=False)
72
+ print(
73
+ "=> loaded '{}' from checkpoint '{}' with msg {}".format(
74
+ key, ckp_path, msg
75
+ )
76
+ )
77
+ except TypeError:
78
+ try:
79
+ msg = value.load_state_dict(checkpoint[key])
80
+ print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
81
+ except ValueError:
82
+ print(
83
+ "=> failed to load '{}' from checkpoint: '{}'".format(
84
+ key, ckp_path
85
+ )
86
+ )
87
+ else:
88
+ print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))
89
+
90
+ # reload variable important for the run
91
+ if run_variables is not None:
92
+ for var_name in run_variables:
93
+ if var_name in checkpoint:
94
+ run_variables[var_name] = checkpoint[var_name]
95
+
96
+
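An illustrative call of the checkpoint helper above (paths and objects are placeholders, shown as comments):

# to_restore = {"epoch": 0}
# restart_from_checkpoint("checkpoint.pth", run_variables=to_restore,
#                         student=student, teacher=teacher, optimizer=optimizer)
# start_epoch = to_restore["epoch"]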
97
+ def bool_flag(s):
98
+ """
99
+ Parse boolean arguments from the command line.
100
+ """
101
+ FALSY_STRINGS = {"off", "false", "0"}
102
+ TRUTHY_STRINGS = {"on", "true", "1"}
103
+ if s.lower() in FALSY_STRINGS:
104
+ return False
105
+ elif s.lower() in TRUTHY_STRINGS:
106
+ return True
107
+ else:
108
+ raise argparse.ArgumentTypeError("invalid value for a boolean flag")
109
+
110
+
111
+ def fix_random_seeds(seed=31):
112
+ """
113
+ Fix random seeds.
114
+ """
115
+ torch.manual_seed(seed)
116
+ torch.cuda.manual_seed_all(seed)
117
+ np.random.seed(seed)
118
+
119
+
120
+ def has_batchnorms(model):
121
+ """
122
+ check whether a model contains batch normalization layers
123
+ """
124
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
125
+ for name, module in model.named_modules():
126
+ if isinstance(module, bn_types):
127
+ return True
128
+ return False
129
+
130
+
131
+ class SmoothedValue(object):
132
+ """Track a series of values and provide access to smoothed values over a
133
+ window or the global series average.
134
+ """
135
+
136
+ def __init__(self, window_size=20, fmt=None):
137
+ if fmt is None:
138
+ fmt = "{median:.6f} ({global_avg:.6f})"
139
+ self.deque = deque(maxlen=window_size)
140
+ self.total = 0.0
141
+ self.count = 0
142
+ self.fmt = fmt
143
+
144
+ def update(self, value, n=1):
145
+ self.deque.append(value)
146
+ self.count += n
147
+ self.total += value * n
148
+
149
+ def synchronize_between_processes(self):
150
+ """
151
+ Warning: does not synchronize the deque!
152
+ """
153
+ if not is_dist_avail_and_initialized():
154
+ return
155
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
156
+ dist.barrier()
157
+ dist.all_reduce(t)
158
+ t = t.tolist()
159
+ self.count = int(t[0])
160
+ self.total = t[1]
161
+
162
+ @property
163
+ def median(self):
164
+ d = torch.tensor(list(self.deque))
165
+ return d.median().item()
166
+
167
+ @property
168
+ def avg(self):
169
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
170
+ return d.mean().item()
171
+
172
+ @property
173
+ def global_avg(self):
174
+ return self.total / self.count
175
+
176
+ @property
177
+ def max(self):
178
+ return max(self.deque)
179
+
180
+ @property
181
+ def value(self):
182
+ return self.deque[-1]
183
+
184
+ def __str__(self):
185
+ return self.fmt.format(
186
+ median=self.median,
187
+ avg=self.avg,
188
+ global_avg=self.global_avg,
189
+ max=self.max,
190
+ value=self.value,
191
+ )
+
+
+ class MetricLogger(object):
+     """
+     Log training metrics and print smoothed values at a regular interval.
+     """
+
+     def __init__(self, delimiter="\t"):
+         self.meters = defaultdict(SmoothedValue)
+         self.delimiter = delimiter
+
+     def update(self, **kwargs):
+         for k, v in kwargs.items():
+             if isinstance(v, torch.Tensor):
+                 v = v.item()
+             assert isinstance(v, (float, int))
+             self.meters[k].update(v)
+
+     def __getattr__(self, attr):
+         if attr in self.meters:
+             return self.meters[attr]
+         if attr in self.__dict__:
+             return self.__dict__[attr]
+         raise AttributeError(
+             "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+         )
+
+     def __str__(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append("{}: {}".format(name, str(meter)))
+         return self.delimiter.join(loss_str)
+
+     def synchronize_between_processes(self):
+         for meter in self.meters.values():
+             meter.synchronize_between_processes()
+
+     def add_meter(self, name, meter):
+         self.meters[name] = meter
+
+     def log_every(self, iterable, print_freq, header=None):
+         i = 0
+         if not header:
+             header = ""
+         start_time = time.time()
+         end = time.time()
+         iter_time = SmoothedValue(fmt="{avg:.6f}")
+         data_time = SmoothedValue(fmt="{avg:.6f}")
+         space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+         if torch.cuda.is_available():
+             log_msg = self.delimiter.join(
+                 [
+                     header,
+                     "[{0" + space_fmt + "}/{1}]",
+                     "eta: {eta}",
+                     "{meters}",
+                     "time: {time}",
+                     "data: {data}",
+                     "max mem: {memory:.0f}",
+                 ]
+             )
+         else:
+             log_msg = self.delimiter.join(
+                 [
+                     header,
+                     "[{0" + space_fmt + "}/{1}]",
+                     "eta: {eta}",
+                     "{meters}",
+                     "time: {time}",
+                     "data: {data}",
+                 ]
+             )
+         MB = 1024.0 * 1024.0
+         for obj in iterable:
+             data_time.update(time.time() - end)
+             yield obj
+             iter_time.update(time.time() - end)
+             if i % print_freq == 0 or i == len(iterable) - 1:
+                 eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                 if torch.cuda.is_available():
+                     print(
+                         log_msg.format(
+                             i,
+                             len(iterable),
+                             eta=eta_string,
+                             meters=str(self),
+                             time=str(iter_time),
+                             data=str(data_time),
+                             memory=torch.cuda.max_memory_allocated() / MB,
+                         )
+                     )
+                 else:
+                     print(
+                         log_msg.format(
+                             i,
+                             len(iterable),
+                             eta=eta_string,
+                             meters=str(self),
+                             time=str(iter_time),
+                             data=str(data_time),
+                         )
+                     )
+             i += 1
+             end = time.time()
+         total_time = time.time() - start_time
+         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+         print(
+             "{} Total time: {} ({:.6f} s / it)".format(
+                 header, total_time_str, total_time / len(iterable)
+             )
+         )
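+ # Usage sketch (hypothetical loader/loss names):
+ #     metric_logger = MetricLogger(delimiter="  ")
+ #     for batch in metric_logger.log_every(train_loader, print_freq=10, header="Epoch: [0]"):
+ #         ...
+ #         metric_logger.update(loss=loss.item())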
+
+
+ def get_sha():
+     cwd = os.path.dirname(os.path.abspath(__file__))
+
+     def _run(command):
+         return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
+
+     sha = "N/A"
+     diff = "clean"
+     branch = "N/A"
+     try:
+         sha = _run(["git", "rev-parse", "HEAD"])
+         subprocess.check_output(["git", "diff"], cwd=cwd)
+         diff = _run(["git", "diff-index", "HEAD"])
+         diff = "has uncommitted changes" if diff else "clean"
+         branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+     except Exception:
+         pass
+     message = f"sha: {sha}, status: {diff}, branch: {branch}"
+     return message
+
+
+ def is_dist_avail_and_initialized():
+     """
+     Return True if distributed training is available and initialized.
+     """
+     if not dist.is_available():
+         return False
+     if not dist.is_initialized():
+         return False
+     return True
+
+
+ def get_world_size():
+     """
+     Return the number of distributed processes (1 if not distributed).
+     """
+     if not is_dist_avail_and_initialized():
+         return 1
+     return dist.get_world_size()
+
+
+ def get_rank():
+     """
+     Return the rank of the current process (0 if not distributed).
+     """
+     if not is_dist_avail_and_initialized():
+         return 0
+     return dist.get_rank()
+
+
+ def is_main_process():
+     """
+     Return True if the current process is the main process (rank 0).
+     """
+     return get_rank() == 0
+
+
+ def save_on_master(*args, **kwargs):
+     """
+     Save a checkpoint from the main process only.
+     """
+     if is_main_process():
+         torch.save(*args, **kwargs)
+
+
+ def setup_for_distributed(is_master):
+     """
+     Disable printing on non-master processes; pass force=True to print anyway.
+     """
+     import builtins as __builtin__
+
+     builtin_print = __builtin__.print
+
+     def print(*args, **kwargs):
+         force = kwargs.pop("force", False)
+         if is_master or force:
+             builtin_print(*args, **kwargs)
+
+     __builtin__.print = print
+
+
+ def init_distributed_ddpjob(args=None):
+     """
+     Initialize distributed training for a DDP job (environment set up by the launcher).
+     """
+     if dist.is_available() and dist.is_initialized():
+         return dist.get_world_size(), dist.get_rank()
+
+     try:
+         os.environ["MASTER_PORT"] = "40101"
+         torch.distributed.init_process_group(backend="nccl")
+         world_size = dist.get_world_size()
+         rank = dist.get_rank()
+     except Exception:
+         world_size, rank = 1, 0
+         print("distributed training not available")
+
+     args.world_size, args.rank = world_size, rank
+     args.gpu = args.rank
+     return world_size, rank
+
+
+ def init_distributed_mode(args):
+     """
+     Initialize distributed training for standard launches
+     (torch.distributed.launch, SLURM, or a single GPU).
+     """
+     # launched with torch.distributed.launch
+     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+         args.rank = int(os.environ["RANK"])
+         args.world_size = int(os.environ["WORLD_SIZE"])
+         args.gpu = int(os.environ.get("LOCAL_RANK", 0))
+         print(
+             "args.rank",
+             args.rank,
+             "args.world_size",
+             args.world_size,
+             "args.gpu",
+             args.gpu,
+         )
+         print("get_rank()", get_rank())
+     # launched with submitit on a slurm cluster
+     elif "SLURM_PROCID" in os.environ:
+         args.rank = int(os.environ["SLURM_PROCID"])
+         args.gpu = args.rank % torch.cuda.device_count()
+     # launched naively with `python main.py`;
+     # we manually add MASTER_ADDR and MASTER_PORT to the environment
+     elif torch.cuda.is_available():
+         print("Will run the code on one GPU.")
+         args.rank, args.gpu, args.world_size = 0, 0, 1
+         os.environ["MASTER_ADDR"] = "127.0.0.1"
+         os.environ["MASTER_PORT"] = "2950"
+     else:
+         print("Does not support training without GPU.")
+         sys.exit(1)
+
+     os.environ["MASTER_PORT"] = "6542"
+
+     dist.init_process_group(
+         backend="nccl",
+         init_method=args.dist_url,
+         world_size=args.world_size,
+         rank=args.rank,
+     )
+
+     torch.cuda.set_device(args.gpu)
+     print(
+         "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True
+     )
+     dist.barrier()
+     setup_for_distributed(args.rank == 0)
+
+
+ def accuracy(output, target, topk=(1,)):
+     """
+     Computes the accuracy over the k top predictions for the specified values of k
+     """
+     maxk = max(topk)
+     batch_size = target.size(0)
+     _, pred = output.topk(maxk, 1, True, True)
+     pred = pred.t()
+     correct = pred.eq(target.reshape(1, -1).expand_as(pred))
+     return [correct[:k].reshape(-1).float().sum(0) * 100.0 / batch_size for k in topk]
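+ # Example: acc1, acc5 = accuracy(logits, labels, topk=(1, 5)) returns accuracies
+ # in percent (0-100) as zero-dimensional tensors.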
+
+
+ def multi_scale(samples, model):
+     """
+     Build multi-scale features by averaging model outputs over three input scales.
+     """
+     v = None
+     for s in [1, 1 / 2 ** (1 / 2), 1 / 2]:  # we use 3 different scales
+         if s == 1:
+             inp = samples.clone()
+         else:
+             inp = nn.functional.interpolate(
+                 samples, scale_factor=s, mode="bilinear", align_corners=False
+             )
+         feats = model.forward_knn(inp).clone()
+         if v is None:
+             v = feats
+         else:
+             v += feats
+     v /= 3
+     v /= v.norm()
+     return v
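+ # The summed features are averaged over the three scales and then L2-normalized;
+ # the model is assumed to expose a `forward_knn` method for feature extraction.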
+
+
+ class AllGather(torch.autograd.Function):
+     """
+     Gather tensors from all processes while supporting backpropagation.
+     """
+
+     @staticmethod
+     def forward(ctx, x):
+         if (
+             dist.is_available()
+             and dist.is_initialized()
+             and (dist.get_world_size() > 1)
+         ):
+             outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
+             dist.all_gather(outputs, x)
+             return torch.cat(outputs, 0)
+         return x
+
+     @staticmethod
+     def backward(ctx, grads):
+         if (
+             dist.is_available()
+             and dist.is_initialized()
+             and (dist.get_world_size() > 1)
+         ):
+             s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank()
+             e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1)
+             grads = grads.contiguous()
+             dist.all_reduce(grads)
+             return grads[s:e]
+         return grads
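+ # In backward, the gradient is all-reduced across processes and then sliced back
+ # to this rank's shard, so AllGather.apply(x) stays differentiable w.r.t. the local x.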
+
+
+ class AllReduce(torch.autograd.Function):
+     """
+     Average a tensor across processes while supporting backpropagation.
+     """
+
+     @staticmethod
+     def forward(ctx, x):
+         if (
+             dist.is_available()
+             and dist.is_initialized()
+             and (dist.get_world_size() > 1)
+         ):
+             x = x.contiguous() / dist.get_world_size()
+             dist.all_reduce(x)
+         return x
+
+     @staticmethod
+     def backward(ctx, grads):
+         return grads
+
+
+ def load_pretrained_weights(
+     model, pretrained_weights, checkpoint_key, model_name, patch_size
+ ):
+     if os.path.isfile(pretrained_weights):
+         state_dict = torch.load(pretrained_weights, map_location="cpu")
+         if checkpoint_key is not None and checkpoint_key in state_dict:
+             print(f"Take key {checkpoint_key} in provided checkpoint dict")
+             state_dict = state_dict[checkpoint_key]
+         # remove `module.` prefix
+         state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+         # remove `backbone.` prefix induced by multicrop wrapper
+         state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+         # remove `encoder.` prefix induced by MAE
+         state_dict = {k.replace("encoder.", ""): v for k, v in state_dict.items()}
+         msg = model.load_state_dict(state_dict, strict=False)
+         print(
+             "Pretrained weights found at {} and loaded with msg: {}".format(
+                 pretrained_weights, msg
+             )
+         )
+     else:
+         print(
+             "There are no reference weights available for this model => We use random weights."
+         )
+
+
+ @torch.no_grad()
+ def concat_all_gather(tensor):
+     """
+     Performs all_gather operation on the provided tensors.
+     *** Warning ***: torch.distributed.all_gather has no gradient.
+     """
+     tensors_gather = [
+         torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
+     ]
+     torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+     output = torch.cat(tensors_gather, dim=0)
+     return output
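+ # Unlike AllGather above, this helper runs under @torch.no_grad() and does not
+ # propagate gradients; use AllGather.apply when a differentiable gather is needed.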