hasibzunair committed on
Commit
2895c00
1 Parent(s): c16c4c8
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
dino/README.md ADDED
@@ -0,0 +1,411 @@
1
+ :new: *Please check out our more recent [DINOv2](https://github.com/facebookresearch/dinov2) effort in the same line of work.*
2
+
3
+ # Self-Supervised Vision Transformers with DINO
4
+
5
+ PyTorch implementation and pretrained models for DINO. For details, see **Emerging Properties in Self-Supervised Vision Transformers**.
6
+ [[`blogpost`](https://ai.facebook.com/blog/dino-paws-computer-vision-with-self-supervised-transformers-and-10x-more-efficient-training)] [[`arXiv`](https://arxiv.org/abs/2104.14294)] [[`Yannic Kilcher's video`](https://www.youtube.com/watch?v=h3ij3F3cPIk)]
7
+
8
+ <div align="center">
9
+ <img width="100%" alt="DINO illustration" src=".github/dino.gif">
10
+ </div>
11
+
12
+ ## Pretrained models
13
+ You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint, which contains the backbone and projection head weights for both the student and teacher networks. We also provide the backbone in `onnx` format, as well as detailed arguments and training/evaluation logs. Note that the names `DeiT-S` and `ViT-S` refer to exactly the same architecture.
14
+
15
+ <table>
16
+ <tr>
17
+ <th>arch</th>
18
+ <th>params</th>
19
+ <th>k-nn</th>
20
+ <th>linear</th>
21
+ <th colspan="6">download</th>
22
+ </tr>
23
+ <tr>
24
+ <td>ViT-S/16</td>
25
+ <td>21M</td>
26
+ <td>74.5%</td>
27
+ <td>77.0%</td>
28
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth">backbone only</a></td>
29
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain_full_checkpoint.pth">full ckpt</a></td>
30
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deits16.onnx">onnx</a></td>
31
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/args.txt">args</a></td>
32
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain_log.txt">logs</a></td>
33
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain_eval_linear_log.txt">eval logs</a></td>
34
+ </tr>
35
+ <tr>
36
+ <td>ViT-S/8</td>
37
+ <td>21M</td>
38
+ <td>78.3%</td>
39
+ <td>79.7%</td>
40
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth">backbone only</a></td>
41
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain_full_checkpoint.pth">full ckpt</a></td>
42
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deits8.onnx">onnx</a></td>
43
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/args.txt">args</a></td>
44
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain_log.txt">logs</a></td>
45
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain_eval_linear_log.txt">eval logs</a></td>
46
+ </tr>
47
+ <tr>
48
+ <td>ViT-B/16</td>
49
+ <td>85M</td>
50
+ <td>76.1%</td>
51
+ <td>78.2%</td>
52
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth">backbone only</a></td>
53
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_full_checkpoint.pth">full ckpt</a></td>
54
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitb16.onnx">onnx</a></td>
55
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/args.txt">args</a></td>
56
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_log.txt">logs</a></td>
57
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_eval_linear_log.txt">eval logs</a></td>
58
+ </tr>
59
+ <tr>
60
+ <td>ViT-B/8</td>
61
+ <td>85M</td>
62
+ <td>77.4%</td>
63
+ <td>80.1%</td>
64
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth">backbone only</a></td>
65
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain_full_checkpoint.pth">full ckpt</a></td>
66
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitb8.onnx">onnx</a></td>
67
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/args.txt">args</a></td>
68
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain_log.txt">logs</a></td>
69
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain_eval_linear_log.txt">eval logs</a></td>
70
+ </tr>
71
+ <tr>
72
+ <td>ResNet-50</td>
73
+ <td>23M</td>
74
+ <td>67.5%</td>
75
+ <td>75.3%</td>
76
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth">backbone only</a></td>
77
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain_full_checkpoint.pth">full ckpt</a></td>
78
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50.onnx">onnx</a></td>
79
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/args.txt">args</a></td>
80
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain_log.txt">logs</a></td>
81
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain_eval_linear_log.txt">eval logs</a></td>
82
+ </tr>
83
+ </table>
84
+
85
+ We also release XCiT models ([[`arXiv`](https://arxiv.org/abs/2106.09681)] [[`code`](https://github.com/facebookresearch/xcit)]) trained with DINO:
86
+ <table>
87
+ <tr>
88
+ <th>arch</th>
89
+ <th>params</th>
90
+ <th>k-nn</th>
91
+ <th>linear</th>
92
+ <th colspan="5">download</th>
93
+ </tr>
94
+ <tr>
95
+ <td>xcit_small_12_p16</td>
96
+ <td>26M</td>
97
+ <td>76.0%</td>
98
+ <td>77.8%</td>
99
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth">backbone only</a></td>
100
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain_full_checkpoint.pth">full ckpt</a></td>
101
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/args.txt">args</a></td>
102
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain_log.txt">logs</a></td>
103
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain_eval_linear_log.txt">eval</a></td>
104
+ </tr>
105
+ <tr>
106
+ <td>xcit_small_12_p8</td>
107
+ <td>26M</td>
108
+ <td>77.1%</td>
109
+ <td>79.2%</td>
110
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth">backbone only</a></td>
111
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain_full_checkpoint.pth">full ckpt</a></td>
112
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/args.txt">args</a></td>
113
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain_log.txt">logs</a></td>
114
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain_eval_linear_log.txt">eval</a></td>
115
+ </tr>
116
+ <tr>
117
+ <td>xcit_medium_24_p16</td>
118
+ <td>84M</td>
119
+ <td>76.4%</td>
120
+ <td>78.8%</td>
121
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth">backbone only</a></td>
122
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain_full_checkpoint.pth">full ckpt</a></td>
123
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/args.txt">args</a></td>
124
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain_log.txt">logs</a></td>
125
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain_eval_linear_log.txt">eval</a></td>
126
+ </tr>
127
+ <tr>
128
+ <td>xcit_medium_24_p8</td>
129
+ <td>84M</td>
130
+ <td>77.9%</td>
131
+ <td>80.3%</td>
132
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth">backbone only</a></td>
133
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain_full_checkpoint.pth">full ckpt</a></td>
134
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/args.txt">args</a></td>
135
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain_log.txt">logs</a></td>
136
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain_eval_linear_log.txt">eval</a></td>
137
+ </tr>
138
+ </table>
139
+
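+ The `onnx` backbones linked above can also be run without PyTorch, e.g. with [ONNX Runtime](https://onnxruntime.ai/). The snippet below is only a minimal sketch, not part of this repo: it assumes the exported graph takes a single `float32` batch with the same 224x224, ImageNet-normalized preprocessing as the PyTorch evaluation transforms.
+
+ ```python
+ import numpy as np
+ import onnxruntime as ort
+
+ # Minimal ONNX Runtime sketch; input resolution and normalization are assumptions.
+ session = ort.InferenceSession("dino_deits16.onnx")
+ input_name = session.get_inputs()[0].name  # query the graph for its input name
+
+ # Dummy ImageNet-normalized batch of one 224x224 RGB image.
+ image = np.random.rand(1, 3, 224, 224).astype(np.float32)
+ features = session.run(None, {input_name: image})[0]
+ print(features.shape)  # expected: (1, embedding_dim), e.g. 384 for ViT-S
+ ```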
140
+ ### Pretrained models on PyTorch Hub
141
+ ```python
142
+ import torch
143
+ vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
144
+ vits8 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')
145
+ vitb16 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
146
+ vitb8 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')
147
+ xcit_small_12_p16 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_small_12_p16')
148
+ xcit_small_12_p8 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_small_12_p8')
149
+ xcit_medium_24_p16 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p16')
150
+ xcit_medium_24_p8 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p8')
151
+ resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
152
+ ```
153
+
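+ As a usage sketch (not part of the repo), here is how to extract a [CLS] embedding from an image with one of these backbones; the preprocessing mirrors the resize/center-crop/ImageNet normalization used by the evaluation scripts in this codebase, and the image path is a placeholder.
+
+ ```python
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+
+ model = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
+ model.eval()
+
+ # Same preprocessing as the evaluation scripts of this repo.
+ transform = transforms.Compose([
+     transforms.Resize(256, interpolation=3),
+     transforms.CenterCrop(224),
+     transforms.ToTensor(),
+     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ ])
+
+ img = transform(Image.open("img.png").convert("RGB")).unsqueeze(0)
+ with torch.no_grad():
+     features = model(img)  # (1, 384) [CLS] embedding for ViT-S/16
+ ```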
154
+ ## Training
155
+
156
+ ### Documentation
157
+ Please install [PyTorch](https://pytorch.org/) and download the [ImageNet](https://imagenet.stanford.edu/) dataset. This codebase has been developed with Python 3.6, PyTorch 1.7.1, CUDA 11.0 and torchvision 0.8.2. The exact arguments to reproduce the models presented in our paper can be found in the `args` column of the [pretrained models section](https://github.com/facebookresearch/dino#pretrained-models). For a glimpse at the full documentation of DINO training, please run:
158
+ ```
159
+ python main_dino.py --help
160
+ ```
161
+
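+ The evaluation scripts in this codebase read ImageNet with `torchvision.datasets.ImageFolder`, and `--data_path /path/to/imagenet/train` is expected to follow the same one-folder-per-class layout (the standard torchvision convention; stated here as an assumption). A quick optional sanity check of your data path:
+
+ ```python
+ from torchvision import datasets, transforms
+
+ # Expects /path/to/imagenet/train to contain one sub-directory per class.
+ dataset = datasets.ImageFolder("/path/to/imagenet/train", transform=transforms.ToTensor())
+ print(f"{len(dataset)} images across {len(dataset.classes)} classes")
+ ```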
162
+ ### Vanilla DINO training :sauropod:
163
+ Run DINO with a ViT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training time is 1.75 days, and the resulting checkpoint should reach 69.3% on k-NN eval and 74.0% on linear eval. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_log.txt) and [linear evaluation](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_eval.txt) logs (with batch size 256 at evaluation time) for this run to help reproducibility.
164
+ ```
165
+ python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
166
+ ```
167
+
168
+ ### Multi-node training
169
+ We use Slurm and [submitit](https://github.com/facebookincubator/submitit) (`pip install submitit`). To train on 2 nodes with 8 GPUs each (total 16 GPUs):
170
+ ```
171
+ python run_with_submitit.py --nodes 2 --ngpus 8 --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
172
+ ```
173
+
174
+ <details>
175
+ <summary>
176
+ DINO with ViT-base network.
177
+ </summary>
178
+
179
+ ```
180
+ python run_with_submitit.py --nodes 2 --ngpus 8 --use_volta32 --arch vit_base --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
181
+ ```
182
+
183
+ </details>
184
+
185
+ ### Boosting DINO performance :t-rex:
186
+ You can improve the performance of the vanilla run by:
187
+ - training for more epochs: `--epochs 300`,
188
+ - increasing the teacher temperature: `--teacher_temp 0.07 --warmup_teacher_temp_epochs 30`,
189
+ - removing last layer normalization (only safe with `--arch vit_small`): `--norm_last_layer false`.
190
+
191
+ <details>
192
+ <summary>
193
+ Full command.
194
+ </summary>
195
+
196
+ ```
197
+ python run_with_submitit.py --arch vit_small --epochs 300 --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --norm_last_layer false --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
198
+ ```
199
+
200
+ </details>
201
+
202
+ The resulting pretrained model should reach 73.3% on k-NN eval and 76.0% on linear eval. Training time is 2.6 days with 16 GPUs. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_boost_deitsmall16_log.txt) and [linear evaluation](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_boost_deitsmall16_eval.txt) logs (with batch size 256 at evaluation time) for this run to help reproducibility.
203
+
204
+ ### ResNet-50 and other convnets trainings
205
+ This code also works for training DINO on convolutional networks, such as ResNet-50. We highly recommend adapting some optimization arguments in this case. For example, the following command trains DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs. We provide [training logs](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_rn50_log.txt) and [final checkpoint](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_rn50_checkpoint.pth) for this run.
206
+ ```
207
+ python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
208
+ ```
209
+
210
+ ## Self-attention visualization
211
+ You can look at the self-attention of the [CLS] token on the different heads of the last layer by running:
212
+ ```
213
+ python visualize_attention.py
214
+ ```
215
+
216
+ <div align="center">
217
+ <img width="100%" alt="Self-attention from a Vision Transformer with 8x8 patches trained with DINO" src=".github/attention_maps.png">
218
+ </div>
219
+
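+ If you prefer to compute such maps yourself, the following is a minimal sketch assuming the `get_last_selfattention` helper of this repository's vision transformer (the backbones loaded from PyTorch Hub expose it); the image path and the 480x480 resize are placeholders.
+
+ ```python
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+
+ patch_size = 8
+ model = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')
+ model.eval()
+
+ transform = transforms.Compose([
+     transforms.Resize((480, 480)),  # must be divisible by the patch size
+     transforms.ToTensor(),
+     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ ])
+ img = transform(Image.open("img.png").convert("RGB")).unsqueeze(0)
+
+ with torch.no_grad():
+     attn = model.get_last_selfattention(img)  # (1, num_heads, N, N), N = 1 + num_patches
+ nh = attn.shape[1]
+ w, h = img.shape[-1] // patch_size, img.shape[-2] // patch_size
+ cls_attn = attn[0, :, 0, 1:].reshape(nh, h, w)  # [CLS]-to-patch attention, one map per head
+ ```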
220
+ ## Self-attention video generation
221
+ You can generate videos like the one in the blog post with `video_generation.py`.
222
+
223
+ https://user-images.githubusercontent.com/46140458/116817761-47885e80-ab68-11eb-9975-d61d5a919e13.mp4
224
+
225
+ Extract frames from an input video and generate the attention video:
226
+ ```
227
+ python video_generation.py --pretrained_weights dino_deitsmall8_pretrain.pth \
228
+ --input_path input/video.mp4 \
229
+ --output_path output/ \
230
+ --fps 25
231
+ ```
232
+
233
+ Use a folder of already extracted frames and generate the attention video:
234
+ ```
235
+ python video_generation.py --pretrained_weights dino_deitsmall8_pretrain.pth \
236
+ --input_path output/frames/ \
237
+ --output_path output/ \
238
+ --resize 256
239
+ ```
240
+
241
+ Only generate a video from a folder of attention map images:
242
+ ```
243
+ python video_generation.py --input_path output/attention \
244
+ --output_path output/ \
245
+ --video_only \
246
+ --video_format avi
247
+ ```
248
+
249
+
250
+ ## Evaluation: k-NN classification on ImageNet
251
+ To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run:
252
+ ```
253
+ python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --data_path /path/to/imagenet
254
+ ```
255
+ If you choose not to specify `--pretrained_weights`, then DINO reference weights are used by default. If you instead want to evaluate checkpoints from a run of your own, you can run, for example:
256
+ ```
257
+ python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --pretrained_weights /path/to/checkpoint.pth --checkpoint_key teacher --data_path /path/to/imagenet
258
+ ```
259
+
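+ `eval_knn.py` can also dump the extracted features to disk with `--dump_features /path/to/features/`. As a sketch, you can then rerun the weighted k-NN vote offline with the `knn_classifier` function from that script (paths below are placeholders; the dumped features are already L2-normalized):
+
+ ```python
+ import torch
+ from eval_knn import knn_classifier
+
+ # Features and labels previously saved by eval_knn.py via --dump_features.
+ train_features = torch.load("features/trainfeat.pth")
+ test_features = torch.load("features/testfeat.pth")
+ train_labels = torch.load("features/trainlabels.pth")
+ test_labels = torch.load("features/testlabels.pth")
+
+ # 20 neighbors with temperature 0.07, the defaults that usually work best.
+ top1, top5 = knn_classifier(train_features, train_labels, test_features, test_labels, k=20, T=0.07)
+ print(f"20-NN: top-1 {top1:.2f}%, top-5 {top5:.2f}%")
+ ```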
260
+ ## Evaluation: Linear classification on ImageNet
261
+ To train a supervised linear classifier on frozen weights on a single node with 8 GPUs, run:
262
+ ```
263
+ python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py --data_path /path/to/imagenet
264
+ ```
265
+
266
+ We release the logs and weights from evaluating the different models:
267
+
268
+ <table>
269
+ <tr>
270
+ <th>arch</th>
271
+ <th>top-1 ImageNet</th>
272
+ <th colspan="2">linear evaluation</th>
273
+ </tr>
274
+ <tr>
275
+ <td>ViT-S/16</td>
276
+ <td>77.0%</td>
277
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth">linear weights</a></td>
278
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain_eval_linear_log.txt">logs</a></td>
279
+ </tr>
280
+ <tr>
281
+ <td>ViT-S/8</td>
282
+ <td>79.7%</td>
283
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth">linear weights</a></td>
284
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain_eval_linear_log.txt">logs</a></td>
285
+ </tr>
286
+ <tr>
287
+ <td>ViT-B/16</td>
288
+ <td>78.2%</td>
289
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth">linear weights</a></td>
290
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_eval_linear_log.txt">logs</a></td>
291
+ </tr>
292
+ <tr>
293
+ <td>ViT-B/8</td>
294
+ <td>80.1%</td>
295
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth">linear weights</a></td>
296
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain_eval_linear_log.txt">logs</a></td>
297
+ </tr>
298
+ <tr>
299
+ <td>xcit_small_12_p16</td>
300
+ <td>77.8%</td>
301
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_linearweights.pth">linear weights</a></td>
302
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain_eval_linear_log.txt">logs</a></td>
303
+ </tr>
304
+ <tr>
305
+ <td>xcit_small_12_p8</td>
306
+ <td>79.2%</td>
307
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_linearweights.pth">linear weights</a></td>
308
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain_eval_linear_log.txt">logs</a></td>
309
+ </tr>
310
+ <tr>
311
+ <td>xcit_medium_24_p16</td>
312
+ <td>78.8%</td>
313
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_linearweights.pth">linear weights</a></td>
314
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain_eval_linear_log.txt">logs</a></td>
315
+ </tr>
316
+ <tr>
317
+ <td>xcit_medium_24_p8</td>
318
+ <td>80.3%</td>
319
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_linearweights.pth">linear weights</a></td>
320
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain_eval_linear_log.txt">logs</a></td>
321
+ </tr>
322
+ <tr>
323
+ <td>ResNet-50</td>
324
+ <td>75.3%</td>
325
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_linearweights.pth">linear weights</a></td>
326
+ <td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain_eval_linear_log.txt">logs</a></td>
327
+ </tr>
328
+ </table>
329
+
330
+ You can check the performance of the pretrained weights on the ImageNet validation set by running the following commands:
331
+ ```
332
+ python eval_linear.py --evaluate --arch vit_small --patch_size 16 --data_path /path/to/imagenet/train
333
+ ```
334
+
335
+ ```
336
+ python eval_linear.py --evaluate --arch vit_small --patch_size 8 --data_path /path/to/imagenet/train
337
+ ```
338
+
339
+ ```
340
+ python eval_linear.py --evaluate --arch vit_base --patch_size 16 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet/train
341
+ ```
342
+
343
+ ```
344
+ python eval_linear.py --evaluate --arch vit_base --patch_size 8 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet/train
345
+ ```
346
+
347
+ ```
348
+ python eval_linear.py --evaluate --arch resnet50 --data_path /path/to/imagenet/train
349
+ ```
350
+
351
+ ## Evaluation: DAVIS 2017 Video object segmentation
352
+ Please verify that you're using PyTorch version 1.7.1, since we are currently unable to reproduce the results with the more recent PyTorch 1.8.1.
353
+
354
+ **Step 1: Prepare DAVIS 2017 data**
355
+ ```
356
+ cd $HOME
357
+ git clone https://github.com/davisvideochallenge/davis-2017 && cd davis-2017
358
+ ./data/get_davis.sh
359
+ ```
360
+
361
+ **Step 2: Video object segmentation**
362
+ ```
363
+ python eval_video_segmentation.py --data_path $HOME/davis-2017/DAVIS/ --output_dir /path/to/saving_dir
364
+ ```
365
+
366
+ **Step 3: Evaluate the obtained segmentation**
367
+ ```
368
+ git clone https://github.com/davisvideochallenge/davis2017-evaluation $HOME/davis2017-evaluation
369
+ python $HOME/davis2017-evaluation/evaluation_method.py --task semi-supervised --results_path /path/to/saving_dir --davis_path $HOME/davis-2017/DAVIS/
370
+ ```
371
+
372
+ ## Evaluation: Image Retrieval on revisited Oxford and Paris
373
+ Step 1: Prepare revisited Oxford and Paris by following [this repo](https://github.com/filipradenovic/revisitop).
374
+
375
+ Step 2: Image retrieval (if you do not specify weights with `--pretrained_weights`, then by default [DINO weights pretrained on the Google Landmark v2 dataset](https://dl.fbaipublicfiles.com/dino/dino_vitsmall16_googlelandmark_pretrain/dino_vitsmall16_googlelandmark_pretrain.pth) are used).
376
+
377
+ Paris:
378
+ ```
379
+ python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_image_retrieval.py --imsize 512 --multiscale 1 --data_path /path/to/revisited_paris_oxford/ --dataset rparis6k
380
+ ```
381
+
382
+ Oxford:
383
+ ```
384
+ python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_image_retrieval.py --imsize 224 --multiscale 0 --data_path /path/to/revisited_paris_oxford/ --dataset roxford5k
385
+ ```
386
+
387
+ ## Evaluation: Copy detection on Copydays
388
+ Step 1: Prepare [Copydays dataset](https://lear.inrialpes.fr/~jegou/data.php#copydays).
389
+
390
+ Step 2 (optional): Prepare a set of image distractors and a set of images on which to learn the whitening operator.
391
+ In our paper, we use 10k random images from YFCC100M as distractors and 20k random images from YFCC100M (different from the distractors) to compute the whitening operator.
392
+
393
+ Step 3: Run copy detection:
394
+ ```
395
+ python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_copy_detection.py --data_path /path/to/copydays/ --whitening_path /path/to/whitening_data/ --distractors_path /path/to/distractors/
396
+ ```
397
+ We report results on the strong subset. For example, the stdout from the command above contains: `eval on strong mAP=0.858`.
398
+
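+ The mAP numbers are averages of per-query average precision, computed by `score_ap_from_ranks_1` in `eval_copy_detection.py` from the ranks at which the true matches are retrieved. A small worked example (run from the `dino/` directory so the import resolves):
+
+ ```python
+ from eval_copy_detection import score_ap_from_ranks_1
+
+ # A query with 2 relevant images retrieved at ranks 0 and 2 (0-indexed):
+ # the trapezoidal precision/recall area gives an AP of roughly 0.79.
+ ap = score_ap_from_ranks_1(ranks=[0, 2], nres=2)
+ print(f"AP = {ap:.3f}")  # 0.792
+ ```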
399
+ ## License
400
+ This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file.
401
+
402
+ ## Citation
403
+ If you find this repository useful, please consider giving a star :star: and a citation :t-rex::
404
+ ```
405
+ @inproceedings{caron2021emerging,
406
+ title={Emerging Properties in Self-Supervised Vision Transformers},
407
+ author={Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
408
+ booktitle={Proceedings of the International Conference on Computer Vision (ICCV)},
409
+ year={2021}
410
+ }
411
+ ```
dino/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ import sys
2
+ from os.path import dirname, join
3
+
4
+ sys.path.insert(0, join(dirname(__file__), "."))
dino/eval_copy_detection.py ADDED
@@ -0,0 +1,381 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ import pickle
17
+ import argparse
18
+
19
+ import torch
20
+ from torch import nn
21
+ import torch.distributed as dist
22
+ import torch.backends.cudnn as cudnn
23
+ from torchvision import models as torchvision_models
24
+ from torchvision import transforms as pth_transforms
25
+ from PIL import Image, ImageFile
26
+ import numpy as np
27
+
28
+ import utils
29
+ import vision_transformer as vits
30
+ from eval_knn import extract_features
31
+
32
+
33
+ class CopydaysDataset:
34
+ def __init__(self, basedir):
35
+ self.basedir = basedir
36
+ self.block_names = (
37
+ ["original", "strong"]
38
+ + ["jpegqual/%d" % i for i in [3, 5, 8, 10, 15, 20, 30, 50, 75]]
39
+ + ["crops/%d" % i for i in [10, 15, 20, 30, 40, 50, 60, 70, 80]]
40
+ )
41
+ self.nblocks = len(self.block_names)
42
+
43
+ self.query_blocks = range(self.nblocks)
44
+ self.q_block_sizes = np.ones(self.nblocks, dtype=int) * 157
45
+ self.q_block_sizes[1] = 229
46
+ # search only among originals
47
+ self.database_blocks = [0]
48
+
49
+ def get_block(self, i):
50
+ dirname = self.basedir + "/" + self.block_names[i]
51
+ fnames = [
52
+ dirname + "/" + fname
53
+ for fname in sorted(os.listdir(dirname))
54
+ if fname.endswith(".jpg")
55
+ ]
56
+ return fnames
57
+
58
+ def get_block_filenames(self, subdir_name):
59
+ dirname = self.basedir + "/" + subdir_name
60
+ return [
61
+ fname for fname in sorted(os.listdir(dirname)) if fname.endswith(".jpg")
62
+ ]
63
+
64
+ def eval_result(self, ids, distances):
65
+ j0 = 0
66
+ for i in range(self.nblocks):
67
+ j1 = j0 + self.q_block_sizes[i]
68
+ block_name = self.block_names[i]
69
+ I = ids[j0:j1] # block size
70
+ sum_AP = 0
71
+ if block_name != "strong":
72
+ # 1:1 mapping of files to names
73
+ positives_per_query = [[i] for i in range(j1 - j0)]
74
+ else:
75
+ originals = self.get_block_filenames("original")
76
+ strongs = self.get_block_filenames("strong")
77
+
78
+ # check if prefixes match
79
+ positives_per_query = [
80
+ [j for j, bname in enumerate(originals) if bname[:4] == qname[:4]]
81
+ for qname in strongs
82
+ ]
83
+
84
+ for qno, Iline in enumerate(I):
85
+ positives = positives_per_query[qno]
86
+ ranks = []
87
+ for rank, bno in enumerate(Iline):
88
+ if bno in positives:
89
+ ranks.append(rank)
90
+ sum_AP += score_ap_from_ranks_1(ranks, len(positives))
91
+
92
+ print("eval on %s mAP=%.3f" % (block_name, sum_AP / (j1 - j0)))
93
+ j0 = j1
94
+
95
+
96
+ # from the Holidays evaluation package
97
+ def score_ap_from_ranks_1(ranks, nres):
98
+ """Compute the average precision of one search.
99
+ ranks = ordered list of ranks of true positives
100
+ nres = total number of positives in dataset
101
+ """
102
+
103
+ # accumulate trapezoids in PR-plot
104
+ ap = 0.0
105
+
106
+ # All have an x-size of:
107
+ recall_step = 1.0 / nres
108
+
109
+ for ntp, rank in enumerate(ranks):
110
+
111
+ # y-size on left side of trapezoid:
112
+ # ntp = nb of true positives so far
113
+ # rank = nb of retrieved items so far
114
+ if rank == 0:
115
+ precision_0 = 1.0
116
+ else:
117
+ precision_0 = ntp / float(rank)
118
+
119
+ # y-size on right side of trapezoid:
120
+ # ntp and rank are increased by one
121
+ precision_1 = (ntp + 1) / float(rank + 1)
122
+
123
+ ap += (precision_1 + precision_0) * recall_step / 2.0
124
+
125
+ return ap
126
+
127
+
128
+ class ImgListDataset(torch.utils.data.Dataset):
129
+ def __init__(self, img_list, transform=None):
130
+ self.samples = img_list
131
+ self.transform = transform
132
+
133
+ def __getitem__(self, i):
134
+ with open(self.samples[i], "rb") as f:
135
+ img = Image.open(f)
136
+ img = img.convert("RGB")
137
+ if self.transform is not None:
138
+ img = self.transform(img)
139
+ return img, i
140
+
141
+ def __len__(self):
142
+ return len(self.samples)
143
+
144
+
145
+ def is_image_file(s):
146
+ ext = s.split(".")[-1]
147
+ if ext in ["jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"]:
148
+ return True
149
+ return False
150
+
151
+
152
+ @torch.no_grad()
153
+ def extract_features(image_list, model, args):
154
+ transform = pth_transforms.Compose(
155
+ [
156
+ pth_transforms.Resize((args.imsize, args.imsize), interpolation=3),
157
+ pth_transforms.ToTensor(),
158
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
159
+ ]
160
+ )
161
+ tempdataset = ImgListDataset(image_list, transform=transform)
162
+ data_loader = torch.utils.data.DataLoader(
163
+ tempdataset,
164
+ batch_size=args.batch_size_per_gpu,
165
+ num_workers=args.num_workers,
166
+ drop_last=False,
167
+ sampler=torch.utils.data.DistributedSampler(tempdataset, shuffle=False),
168
+ )
169
+ features = None
170
+ for samples, index in utils.MetricLogger(delimiter=" ").log_every(data_loader, 10):
171
+ samples, index = samples.cuda(non_blocking=True), index.cuda(non_blocking=True)
172
+ feats = model.get_intermediate_layers(samples, n=1)[0].clone()
173
+
174
+ cls_output_token = feats[:, 0, :] # [CLS] token
175
+ # GeM with exponent 4 for output patch tokens
176
+ b, h, w, d = (
177
+ len(samples),
178
+ int(samples.shape[-2] / model.patch_embed.patch_size),
179
+ int(samples.shape[-1] / model.patch_embed.patch_size),
180
+ feats.shape[-1],
181
+ )
182
+ feats = feats[:, 1:, :].reshape(b, h, w, d)
183
+ feats = feats.clamp(min=1e-6).permute(0, 3, 1, 2)
184
+ feats = (
185
+ nn.functional.avg_pool2d(feats.pow(4), (h, w)).pow(1.0 / 4).reshape(b, -1)
186
+ )
187
+ # concatenate [CLS] token and GeM pooled patch tokens
188
+ feats = torch.cat((cls_output_token, feats), dim=1)
189
+
190
+ # init storage feature matrix
191
+ if dist.get_rank() == 0 and features is None:
192
+ features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
193
+ if args.use_cuda:
194
+ features = features.cuda(non_blocking=True)
195
+
196
+ # get indexes from all processes
197
+ y_all = torch.empty(
198
+ dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device
199
+ )
200
+ y_l = list(y_all.unbind(0))
201
+ y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
202
+ y_all_reduce.wait()
203
+ index_all = torch.cat(y_l)
204
+
205
+ # share features between processes
206
+ feats_all = torch.empty(
207
+ dist.get_world_size(),
208
+ feats.size(0),
209
+ feats.size(1),
210
+ dtype=feats.dtype,
211
+ device=feats.device,
212
+ )
213
+ output_l = list(feats_all.unbind(0))
214
+ output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
215
+ output_all_reduce.wait()
216
+
217
+ # update storage feature matrix
218
+ if dist.get_rank() == 0:
219
+ if args.use_cuda:
220
+ features.index_copy_(0, index_all, torch.cat(output_l))
221
+ else:
222
+ features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
223
+ return features # features is still None for every rank which is not 0 (main)
224
+
225
+
226
+ if __name__ == "__main__":
227
+ parser = argparse.ArgumentParser("Copy detection on Copydays")
228
+ parser.add_argument(
229
+ "--data_path",
230
+ default="/path/to/copydays/",
231
+ type=str,
232
+ help="See https://lear.inrialpes.fr/~jegou/data.php#copydays",
233
+ )
234
+ parser.add_argument(
235
+ "--whitening_path",
236
+ default="/path/to/whitening_data/",
237
+ type=str,
238
+ help="""Path to directory with images used for computing the whitening operator.
239
+ In our paper, we use 20k random images from YFCC100M.""",
240
+ )
241
+ parser.add_argument(
242
+ "--distractors_path",
243
+ default="/path/to/distractors/",
244
+ type=str,
245
+ help="Path to directory with distractors images. In our paper, we use 10k random images from YFCC100M.",
246
+ )
247
+ parser.add_argument(
248
+ "--imsize", default=320, type=int, help="Image size (square image)"
249
+ )
250
+ parser.add_argument(
251
+ "--batch_size_per_gpu", default=16, type=int, help="Per-GPU batch-size"
252
+ )
253
+ parser.add_argument(
254
+ "--pretrained_weights",
255
+ default="",
256
+ type=str,
257
+ help="Path to pretrained weights to evaluate.",
258
+ )
259
+ parser.add_argument("--use_cuda", default=True, type=utils.bool_flag)
260
+ parser.add_argument("--arch", default="vit_base", type=str, help="Architecture")
261
+ parser.add_argument(
262
+ "--patch_size", default=8, type=int, help="Patch resolution of the model."
263
+ )
264
+ parser.add_argument(
265
+ "--checkpoint_key",
266
+ default="teacher",
267
+ type=str,
268
+ help='Key to use in the checkpoint (example: "teacher")',
269
+ )
270
+ parser.add_argument(
271
+ "--num_workers",
272
+ default=10,
273
+ type=int,
274
+ help="Number of data loading workers per GPU.",
275
+ )
276
+ parser.add_argument(
277
+ "--dist_url",
278
+ default="env://",
279
+ type=str,
280
+ help="""url used to set up
281
+ distributed training; see https://pytorch.org/docs/stable/distributed.html""",
282
+ )
283
+ parser.add_argument(
284
+ "--local_rank",
285
+ default=0,
286
+ type=int,
287
+ help="Please ignore and do not set this argument.",
288
+ )
289
+ args = parser.parse_args()
290
+
291
+ utils.init_distributed_mode(args)
292
+ print("git:\n {}\n".format(utils.get_sha()))
293
+ print(
294
+ "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
295
+ )
296
+ cudnn.benchmark = True
297
+
298
+ # ============ building network ... ============
299
+ if "vit" in args.arch:
300
+ model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
301
+ print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
302
+ else:
303
+ print(f"Architecture {args.arch} non supported")
304
+ sys.exit(1)
305
+ if args.use_cuda:
306
+ model.cuda()
307
+ model.eval()
308
+ utils.load_pretrained_weights(
309
+ model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size
310
+ )
311
+
312
+ dataset = CopydaysDataset(args.data_path)
313
+
314
+ # ============ Extract features ... ============
315
+ # extract features for queries
316
+ queries = []
317
+ for q in dataset.query_blocks:
318
+ queries.append(extract_features(dataset.get_block(q), model, args))
319
+ if utils.get_rank() == 0:
320
+ queries = torch.cat(queries)
321
+ print(f"Extraction of queries features done. Shape: {queries.shape}")
322
+
323
+ # extract features for database
324
+ database = []
325
+ for b in dataset.database_blocks:
326
+ database.append(extract_features(dataset.get_block(b), model, args))
327
+
328
+ # extract features for distractors
329
+ if os.path.isdir(args.distractors_path):
330
+ print("Using distractors...")
331
+ list_distractors = [
332
+ os.path.join(args.distractors_path, s)
333
+ for s in os.listdir(args.distractors_path)
334
+ if is_image_file(s)
335
+ ]
336
+ database.append(extract_features(list_distractors, model, args))
337
+ if utils.get_rank() == 0:
338
+ database = torch.cat(database)
339
+ print(
340
+ f"Extraction of database and distractors features done. Shape: {database.shape}"
341
+ )
342
+
343
+ # ============ Whitening ... ============
344
+ if os.path.isdir(args.whitening_path):
345
+ print(
346
+ f"Extracting features on images from {args.whitening_path} for learning the whitening operator."
347
+ )
348
+ list_whit = [
349
+ os.path.join(args.whitening_path, s)
350
+ for s in os.listdir(args.whitening_path)
351
+ if is_image_file(s)
352
+ ]
353
+ features_for_whitening = extract_features(list_whit, model, args)
354
+ if utils.get_rank() == 0:
355
+ # center
356
+ mean_feature = torch.mean(features_for_whitening, dim=0)
357
+ database -= mean_feature
358
+ queries -= mean_feature
359
+ pca = utils.PCA(dim=database.shape[-1], whit=0.5)
360
+ # compute covariance
361
+ cov = (
362
+ torch.mm(features_for_whitening.T, features_for_whitening)
363
+ / features_for_whitening.shape[0]
364
+ )
365
+ pca.train_pca(cov.cpu().numpy())
366
+ database = pca.apply(database)
367
+ queries = pca.apply(queries)
368
+
369
+ # ============ Copy detection ... ============
370
+ if utils.get_rank() == 0:
371
+ # l2 normalize the features
372
+ database = nn.functional.normalize(database, dim=1, p=2)
373
+ queries = nn.functional.normalize(queries, dim=1, p=2)
374
+
375
+ # similarity
376
+ similarity = torch.mm(queries, database.T)
377
+ distances, indices = similarity.topk(20, largest=True, sorted=True)
378
+
379
+ # evaluate
380
+ retrieved = dataset.eval_result(indices, distances)
381
+ dist.barrier()
dino/eval_image_retrieval.py ADDED
@@ -0,0 +1,274 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ import pickle
17
+ import argparse
18
+
19
+ import torch
20
+ from torch import nn
21
+ import torch.distributed as dist
22
+ import torch.backends.cudnn as cudnn
23
+ from torchvision import models as torchvision_models
24
+ from torchvision import transforms as pth_transforms
25
+ from PIL import Image, ImageFile
26
+ import numpy as np
27
+
28
+ import utils
29
+ import vision_transformer as vits
30
+ from eval_knn import extract_features
31
+
32
+
33
+ class OxfordParisDataset(torch.utils.data.Dataset):
34
+ def __init__(self, dir_main, dataset, split, transform=None, imsize=None):
35
+ if dataset not in ["roxford5k", "rparis6k"]:
36
+ raise ValueError("Unknown dataset: {}!".format(dataset))
37
+
38
+ # loading imlist, qimlist, and gnd, in cfg as a dict
39
+ gnd_fname = os.path.join(dir_main, dataset, "gnd_{}.pkl".format(dataset))
40
+ with open(gnd_fname, "rb") as f:
41
+ cfg = pickle.load(f)
42
+ cfg["gnd_fname"] = gnd_fname
43
+ cfg["ext"] = ".jpg"
44
+ cfg["qext"] = ".jpg"
45
+ cfg["dir_data"] = os.path.join(dir_main, dataset)
46
+ cfg["dir_images"] = os.path.join(cfg["dir_data"], "jpg")
47
+ cfg["n"] = len(cfg["imlist"])
48
+ cfg["nq"] = len(cfg["qimlist"])
49
+ cfg["im_fname"] = config_imname
50
+ cfg["qim_fname"] = config_qimname
51
+ cfg["dataset"] = dataset
52
+ self.cfg = cfg
53
+
54
+ self.samples = cfg["qimlist"] if split == "query" else cfg["imlist"]
55
+ self.transform = transform
56
+ self.imsize = imsize
57
+
58
+ def __len__(self):
59
+ return len(self.samples)
60
+
61
+ def __getitem__(self, index):
62
+ path = os.path.join(self.cfg["dir_images"], self.samples[index] + ".jpg")
63
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
64
+ with open(path, "rb") as f:
65
+ img = Image.open(f)
66
+ img = img.convert("RGB")
67
+ if self.imsize is not None:
68
+ img.thumbnail((self.imsize, self.imsize), Image.ANTIALIAS)
69
+ if self.transform is not None:
70
+ img = self.transform(img)
71
+ return img, index
72
+
73
+
74
+ def config_imname(cfg, i):
75
+ return os.path.join(cfg["dir_images"], cfg["imlist"][i] + cfg["ext"])
76
+
77
+
78
+ def config_qimname(cfg, i):
79
+ return os.path.join(cfg["dir_images"], cfg["qimlist"][i] + cfg["qext"])
80
+
81
+
82
+ if __name__ == "__main__":
83
+ parser = argparse.ArgumentParser("Image Retrieval on revisited Paris and Oxford")
84
+ parser.add_argument(
85
+ "--data_path", default="/path/to/revisited_paris_oxford/", type=str
86
+ )
87
+ parser.add_argument(
88
+ "--dataset", default="roxford5k", type=str, choices=["roxford5k", "rparis6k"]
89
+ )
90
+ parser.add_argument("--multiscale", default=False, type=utils.bool_flag)
91
+ parser.add_argument("--imsize", default=224, type=int, help="Image size")
92
+ parser.add_argument(
93
+ "--pretrained_weights",
94
+ default="",
95
+ type=str,
96
+ help="Path to pretrained weights to evaluate.",
97
+ )
98
+ parser.add_argument("--use_cuda", default=True, type=utils.bool_flag)
99
+ parser.add_argument("--arch", default="vit_small", type=str, help="Architecture")
100
+ parser.add_argument(
101
+ "--patch_size", default=16, type=int, help="Patch resolution of the model."
102
+ )
103
+ parser.add_argument(
104
+ "--checkpoint_key",
105
+ default="teacher",
106
+ type=str,
107
+ help='Key to use in the checkpoint (example: "teacher")',
108
+ )
109
+ parser.add_argument(
110
+ "--num_workers",
111
+ default=10,
112
+ type=int,
113
+ help="Number of data loading workers per GPU.",
114
+ )
115
+ parser.add_argument(
116
+ "--dist_url",
117
+ default="env://",
118
+ type=str,
119
+ help="""url used to set up
120
+ distributed training; see https://pytorch.org/docs/stable/distributed.html""",
121
+ )
122
+ parser.add_argument(
123
+ "--local_rank",
124
+ default=0,
125
+ type=int,
126
+ help="Please ignore and do not set this argument.",
127
+ )
128
+ args = parser.parse_args()
129
+
130
+ utils.init_distributed_mode(args)
131
+ print("git:\n {}\n".format(utils.get_sha()))
132
+ print(
133
+ "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
134
+ )
135
+ cudnn.benchmark = True
136
+
137
+ # ============ preparing data ... ============
138
+ transform = pth_transforms.Compose(
139
+ [
140
+ pth_transforms.ToTensor(),
141
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
142
+ ]
143
+ )
144
+ dataset_train = OxfordParisDataset(
145
+ args.data_path,
146
+ args.dataset,
147
+ split="train",
148
+ transform=transform,
149
+ imsize=args.imsize,
150
+ )
151
+ dataset_query = OxfordParisDataset(
152
+ args.data_path,
153
+ args.dataset,
154
+ split="query",
155
+ transform=transform,
156
+ imsize=args.imsize,
157
+ )
158
+ sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
159
+ data_loader_train = torch.utils.data.DataLoader(
160
+ dataset_train,
161
+ sampler=sampler,
162
+ batch_size=1,
163
+ num_workers=args.num_workers,
164
+ pin_memory=True,
165
+ drop_last=False,
166
+ )
167
+ data_loader_query = torch.utils.data.DataLoader(
168
+ dataset_query,
169
+ batch_size=1,
170
+ num_workers=args.num_workers,
171
+ pin_memory=True,
172
+ drop_last=False,
173
+ )
174
+ print(f"train: {len(dataset_train)} imgs / query: {len(dataset_query)} imgs")
175
+
176
+ # ============ building network ... ============
177
+ if "vit" in args.arch:
178
+ model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
179
+ print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
180
+ elif "xcit" in args.arch:
181
+ model = torch.hub.load("facebookresearch/xcit:main", args.arch, num_classes=0)
182
+ elif args.arch in torchvision_models.__dict__.keys():
183
+ model = torchvision_models.__dict__[args.arch](num_classes=0)
184
+ else:
185
+ print(f"Architecture {args.arch} non supported")
186
+ sys.exit(1)
187
+ if args.use_cuda:
188
+ model.cuda()
189
+ model.eval()
190
+
191
+ # load pretrained weights
192
+ if os.path.isfile(args.pretrained_weights):
193
+ state_dict = torch.load(args.pretrained_weights, map_location="cpu")
194
+ if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
195
+ print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
196
+ state_dict = state_dict[args.checkpoint_key]
197
+ # remove `module.` prefix
198
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
199
+ # remove `backbone.` prefix induced by multicrop wrapper
200
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
201
+ msg = model.load_state_dict(state_dict, strict=False)
202
+ print(
203
+ "Pretrained weights found at {} and loaded with msg: {}".format(
204
+ args.pretrained_weights, msg
205
+ )
206
+ )
207
+ elif args.arch == "vit_small" and args.patch_size == 16:
208
+ print(
209
+ "Since no pretrained weights have been provided, we load pretrained DINO weights on Google Landmark v2."
210
+ )
211
+ model.load_state_dict(
212
+ torch.hub.load_state_dict_from_url(
213
+ url="https://dl.fbaipublicfiles.com/dino/dino_vitsmall16_googlelandmark_pretrain/dino_vitsmall16_googlelandmark_pretrain.pth"
214
+ )
215
+ )
216
+ else:
217
+ print("Warning: We use random weights.")
218
+
219
+ ############################################################################
220
+ # Step 1: extract features
221
+ train_features = extract_features(
222
+ model, data_loader_train, args.use_cuda, multiscale=args.multiscale
223
+ )
224
+ query_features = extract_features(
225
+ model, data_loader_query, args.use_cuda, multiscale=args.multiscale
226
+ )
227
+
228
+ if utils.get_rank() == 0: # only rank 0 will work from now on
229
+ # normalize features
230
+ train_features = nn.functional.normalize(train_features, dim=1, p=2)
231
+ query_features = nn.functional.normalize(query_features, dim=1, p=2)
232
+
233
+ ############################################################################
234
+ # Step 2: similarity
235
+ sim = torch.mm(train_features, query_features.T)
236
+ ranks = torch.argsort(-sim, dim=0).cpu().numpy()
237
+
238
+ ############################################################################
239
+ # Step 3: evaluate
240
+ gnd = dataset_train.cfg["gnd"]
241
+ # evaluate ranks
242
+ ks = [1, 5, 10]
243
+ # search for easy & hard
244
+ gnd_t = []
245
+ for i in range(len(gnd)):
246
+ g = {}
247
+ g["ok"] = np.concatenate([gnd[i]["easy"], gnd[i]["hard"]])
248
+ g["junk"] = np.concatenate([gnd[i]["junk"]])
249
+ gnd_t.append(g)
250
+ mapM, apsM, mprM, prsM = utils.compute_map(ranks, gnd_t, ks)
251
+ # search for hard
252
+ gnd_t = []
253
+ for i in range(len(gnd)):
254
+ g = {}
255
+ g["ok"] = np.concatenate([gnd[i]["hard"]])
256
+ g["junk"] = np.concatenate([gnd[i]["junk"], gnd[i]["easy"]])
257
+ gnd_t.append(g)
258
+ mapH, apsH, mprH, prsH = utils.compute_map(ranks, gnd_t, ks)
259
+ print(
260
+ ">> {}: mAP M: {}, H: {}".format(
261
+ args.dataset,
262
+ np.around(mapM * 100, decimals=2),
263
+ np.around(mapH * 100, decimals=2),
264
+ )
265
+ )
266
+ print(
267
+ ">> {}: mP@k{} M: {}, H: {}".format(
268
+ args.dataset,
269
+ np.array(ks),
270
+ np.around(mprM * 100, decimals=2),
271
+ np.around(mprH * 100, decimals=2),
272
+ )
273
+ )
274
+ dist.barrier()
dino/eval_knn.py ADDED
@@ -0,0 +1,325 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ import argparse
17
+
18
+ import torch
19
+ from torch import nn
20
+ import torch.distributed as dist
21
+ import torch.backends.cudnn as cudnn
22
+ from torchvision import datasets
23
+ from torchvision import transforms as pth_transforms
24
+ from torchvision import models as torchvision_models
25
+
26
+ import utils
27
+ import vision_transformer as vits
28
+
29
+
30
+ def extract_feature_pipeline(args):
31
+ # ============ preparing data ... ============
32
+ transform = pth_transforms.Compose(
33
+ [
34
+ pth_transforms.Resize(256, interpolation=3),
35
+ pth_transforms.CenterCrop(224),
36
+ pth_transforms.ToTensor(),
37
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
38
+ ]
39
+ )
40
+ dataset_train = ReturnIndexDataset(
41
+ os.path.join(args.data_path, "train"), transform=transform
42
+ )
43
+ dataset_val = ReturnIndexDataset(
44
+ os.path.join(args.data_path, "val"), transform=transform
45
+ )
46
+ sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
47
+ data_loader_train = torch.utils.data.DataLoader(
48
+ dataset_train,
49
+ sampler=sampler,
50
+ batch_size=args.batch_size_per_gpu,
51
+ num_workers=args.num_workers,
52
+ pin_memory=True,
53
+ drop_last=False,
54
+ )
55
+ data_loader_val = torch.utils.data.DataLoader(
56
+ dataset_val,
57
+ batch_size=args.batch_size_per_gpu,
58
+ num_workers=args.num_workers,
59
+ pin_memory=True,
60
+ drop_last=False,
61
+ )
62
+ print(
63
+ f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs."
64
+ )
65
+
66
+ # ============ building network ... ============
67
+ if "vit" in args.arch:
68
+ model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
69
+ print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
70
+ elif "xcit" in args.arch:
71
+ model = torch.hub.load("facebookresearch/xcit:main", args.arch, num_classes=0)
72
+ elif args.arch in torchvision_models.__dict__.keys():
73
+ model = torchvision_models.__dict__[args.arch](num_classes=0)
74
+ model.fc = nn.Identity()
75
+ else:
76
+ print(f"Architecture {args.arch} non supported")
77
+ sys.exit(1)
78
+ model.cuda()
79
+ utils.load_pretrained_weights(
80
+ model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size
81
+ )
82
+ model.eval()
83
+
84
+ # ============ extract features ... ============
85
+ print("Extracting features for train set...")
86
+ train_features = extract_features(model, data_loader_train, args.use_cuda)
87
+ print("Extracting features for val set...")
88
+ test_features = extract_features(model, data_loader_val, args.use_cuda)
89
+
90
+ if utils.get_rank() == 0:
91
+ train_features = nn.functional.normalize(train_features, dim=1, p=2)
92
+ test_features = nn.functional.normalize(test_features, dim=1, p=2)
93
+
94
+ train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long()
95
+ test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long()
96
+ # save features and labels
97
+ if args.dump_features and dist.get_rank() == 0:
98
+ torch.save(
99
+ train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth")
100
+ )
101
+ torch.save(
102
+ test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth")
103
+ )
104
+ torch.save(
105
+ train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth")
106
+ )
107
+ torch.save(
108
+ test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth")
109
+ )
110
+ return train_features, test_features, train_labels, test_labels
111
+
112
+
113
+ @torch.no_grad()
114
+ def extract_features(model, data_loader, use_cuda=True, multiscale=False):
115
+ metric_logger = utils.MetricLogger(delimiter=" ")
116
+ features = None
117
+ for samples, index in metric_logger.log_every(data_loader, 10):
118
+ samples = samples.cuda(non_blocking=True)
119
+ index = index.cuda(non_blocking=True)
120
+ if multiscale:
121
+ feats = utils.multi_scale(samples, model)
122
+ else:
123
+ feats = model(samples).clone()
124
+
125
+ # init storage feature matrix
126
+ if dist.get_rank() == 0 and features is None:
127
+ features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
128
+ if use_cuda:
129
+ features = features.cuda(non_blocking=True)
130
+ print(f"Storing features into tensor of shape {features.shape}")
131
+
132
+ # get indexes from all processes
133
+ y_all = torch.empty(
134
+ dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device
135
+ )
136
+ y_l = list(y_all.unbind(0))
137
+ y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
138
+ y_all_reduce.wait()
139
+ index_all = torch.cat(y_l)
140
+
141
+ # share features between processes
142
+ feats_all = torch.empty(
143
+ dist.get_world_size(),
144
+ feats.size(0),
145
+ feats.size(1),
146
+ dtype=feats.dtype,
147
+ device=feats.device,
148
+ )
149
+ output_l = list(feats_all.unbind(0))
150
+ output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
151
+ output_all_reduce.wait()
152
+
153
+ # update storage feature matrix
154
+ if dist.get_rank() == 0:
155
+ if use_cuda:
156
+ features.index_copy_(0, index_all, torch.cat(output_l))
157
+ else:
158
+ features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
159
+ return features
160
+
161
+
162
+ @torch.no_grad()
163
+ def knn_classifier(
164
+ train_features, train_labels, test_features, test_labels, k, T, num_classes=1000
165
+ ):
166
+ top1, top5, total = 0.0, 0.0, 0
167
+ train_features = train_features.t()
168
+ num_test_images, num_chunks = test_labels.shape[0], 100
169
+ imgs_per_chunk = num_test_images // num_chunks
170
+ retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)
171
+ for idx in range(0, num_test_images, imgs_per_chunk):
172
+ # get the features for test images
173
+ features = test_features[idx : min((idx + imgs_per_chunk), num_test_images), :]
174
+ targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)]
175
+ batch_size = targets.shape[0]
176
+
177
+ # calculate the dot product and compute top-k neighbors
178
+ similarity = torch.mm(features, train_features)
179
+ distances, indices = similarity.topk(k, largest=True, sorted=True)
180
+ candidates = train_labels.view(1, -1).expand(batch_size, -1)
181
+ retrieved_neighbors = torch.gather(candidates, 1, indices)
182
+
183
+ retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
184
+ retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
185
+ distances_transform = distances.clone().div_(T).exp_()
186
+ probs = torch.sum(
187
+ torch.mul(
188
+ retrieval_one_hot.view(batch_size, -1, num_classes),
189
+ distances_transform.view(batch_size, -1, 1),
190
+ ),
191
+ 1,
192
+ )
193
+ _, predictions = probs.sort(1, True)
194
+
195
+ # find the predictions that match the target
196
+ correct = predictions.eq(targets.data.view(-1, 1))
197
+ top1 = top1 + correct.narrow(1, 0, 1).sum().item()
198
+ top5 = (
199
+ top5 + correct.narrow(1, 0, min(5, k)).sum().item()
200
+ ) # top5 does not make sense if k < 5
201
+ total += targets.size(0)
202
+ top1 = top1 * 100.0 / total
203
+ top5 = top5 * 100.0 / total
204
+ return top1, top5
205
+
206
+
207
+ class ReturnIndexDataset(datasets.ImageFolder):
208
+ def __getitem__(self, idx):
209
+ img, lab = super(ReturnIndexDataset, self).__getitem__(idx)
210
+ return img, idx
211
+
212
+
213
+ if __name__ == "__main__":
214
+ parser = argparse.ArgumentParser("Evaluation with weighted k-NN on ImageNet")
215
+ parser.add_argument(
216
+ "--batch_size_per_gpu", default=128, type=int, help="Per-GPU batch-size"
217
+ )
218
+ parser.add_argument(
219
+ "--nb_knn",
220
+ default=[10, 20, 100, 200],
221
+ nargs="+",
222
+ type=int,
223
+ help="Number of NN to use. 20 is usually working the best.",
224
+ )
225
+ parser.add_argument(
226
+ "--temperature",
227
+ default=0.07,
228
+ type=float,
229
+ help="Temperature used in the voting coefficient",
230
+ )
231
+ parser.add_argument(
232
+ "--pretrained_weights",
233
+ default="",
234
+ type=str,
235
+ help="Path to pretrained weights to evaluate.",
236
+ )
237
+ parser.add_argument(
238
+ "--use_cuda",
239
+ default=True,
240
+ type=utils.bool_flag,
241
+ help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM",
242
+ )
243
+ parser.add_argument("--arch", default="vit_small", type=str, help="Architecture")
244
+ parser.add_argument(
245
+ "--patch_size", default=16, type=int, help="Patch resolution of the model."
246
+ )
247
+ parser.add_argument(
248
+ "--checkpoint_key",
249
+ default="teacher",
250
+ type=str,
251
+ help='Key to use in the checkpoint (example: "teacher")',
252
+ )
253
+ parser.add_argument(
254
+ "--dump_features",
255
+ default=None,
256
+ help="Path where to save computed features, empty for no saving",
257
+ )
258
+ parser.add_argument(
259
+ "--load_features",
260
+ default=None,
261
+ help="""If the features have
262
+ already been computed, where to find them.""",
263
+ )
264
+ parser.add_argument(
265
+ "--num_workers",
266
+ default=10,
267
+ type=int,
268
+ help="Number of data loading workers per GPU.",
269
+ )
270
+ parser.add_argument(
271
+ "--dist_url",
272
+ default="env://",
273
+ type=str,
274
+ help="""url used to set up
275
+ distributed training; see https://pytorch.org/docs/stable/distributed.html""",
276
+ )
277
+ parser.add_argument(
278
+ "--local_rank",
279
+ default=0,
280
+ type=int,
281
+ help="Please ignore and do not set this argument.",
282
+ )
283
+ parser.add_argument("--data_path", default="/path/to/imagenet/", type=str)
284
+ args = parser.parse_args()
285
+
286
+ utils.init_distributed_mode(args)
287
+ print("git:\n {}\n".format(utils.get_sha()))
288
+ print(
289
+ "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
290
+ )
291
+ cudnn.benchmark = True
292
+
293
+ if args.load_features:
294
+ train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth"))
295
+ test_features = torch.load(os.path.join(args.load_features, "testfeat.pth"))
296
+ train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth"))
297
+ test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth"))
298
+ else:
299
+ # need to extract features !
300
+ (
301
+ train_features,
302
+ test_features,
303
+ train_labels,
304
+ test_labels,
305
+ ) = extract_feature_pipeline(args)
306
+
307
+ if utils.get_rank() == 0:
308
+ if args.use_cuda:
309
+ train_features = train_features.cuda()
310
+ test_features = test_features.cuda()
311
+ train_labels = train_labels.cuda()
312
+ test_labels = test_labels.cuda()
313
+
314
+ print("Features are ready!\nStart the k-NN classification.")
315
+ for k in args.nb_knn:
316
+ top1, top5 = knn_classifier(
317
+ train_features,
318
+ train_labels,
319
+ test_features,
320
+ test_labels,
321
+ k,
322
+ args.temperature,
323
+ )
324
+ print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}")
325
+ dist.barrier()
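
For reference, the weighted k-NN vote implemented in `knn_classifier` above can be exercised in isolation. The sketch below reproduces the same voting rule on random, L2-normalized features; the sizes, class count and toy data are illustrative placeholders, not part of the original script.

```python
# Minimal sketch of the weighted k-NN vote used in knn_classifier, on toy data.
# All sizes here are illustrative placeholders.
import torch
import torch.nn.functional as F

num_train, num_test, dim, num_classes, k, T = 1000, 8, 64, 10, 20, 0.07
train_features = F.normalize(torch.randn(num_train, dim), dim=1, p=2)
test_features = F.normalize(torch.randn(num_test, dim), dim=1, p=2)
train_labels = torch.randint(num_classes, (num_train,))

similarity = test_features @ train_features.t()            # cosine similarity (features are L2-normalized)
distances, indices = similarity.topk(k, largest=True)      # k nearest training samples per test sample
neighbor_labels = train_labels[indices]                    # (num_test, k)
one_hot = torch.zeros(num_test, k, num_classes)
one_hot.scatter_(2, neighbor_labels.unsqueeze(-1), 1)      # one-hot labels of the retrieved neighbors
weights = distances.div(T).exp().unsqueeze(-1)             # temperature-scaled voting weights
predictions = (one_hot * weights).sum(dim=1).argmax(dim=1) # weighted class vote
print(predictions.shape)                                   # torch.Size([8])
```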
dino/eval_linear.py ADDED
@@ -0,0 +1,421 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
+ import sys
15
+ import argparse
16
+ import json
17
+ from pathlib import Path
18
+
19
+ import torch
20
+ from torch import nn
21
+ import torch.distributed as dist
22
+ import torch.backends.cudnn as cudnn
23
+ from torchvision import datasets
24
+ from torchvision import transforms as pth_transforms
25
+ from torchvision import models as torchvision_models
26
+
27
+ import utils
28
+ import vision_transformer as vits
29
+
30
+
31
+ def eval_linear(args):
32
+ utils.init_distributed_mode(args)
33
+ print("git:\n {}\n".format(utils.get_sha()))
34
+ print(
35
+ "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
36
+ )
37
+ cudnn.benchmark = True
38
+
39
+ # ============ building network ... ============
40
+ # if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
41
+ if args.arch in vits.__dict__.keys():
42
+ model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
43
+ embed_dim = model.embed_dim * (
44
+ args.n_last_blocks + int(args.avgpool_patchtokens)
45
+ )
46
+ # if the network is a XCiT
47
+ elif "xcit" in args.arch:
48
+ model = torch.hub.load("facebookresearch/xcit:main", args.arch, num_classes=0)
49
+ embed_dim = model.embed_dim
50
+ # otherwise, we check if the architecture is in torchvision models
51
+ elif args.arch in torchvision_models.__dict__.keys():
52
+ model = torchvision_models.__dict__[args.arch]()
53
+ embed_dim = model.fc.weight.shape[1]
54
+ model.fc = nn.Identity()
55
+ else:
56
+ print(f"Unknow architecture: {args.arch}")
57
+ sys.exit(1)
58
+ model.cuda()
59
+ model.eval()
60
+ # load weights to evaluate
61
+ utils.load_pretrained_weights(
62
+ model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size
63
+ )
64
+ print(f"Model {args.arch} built.")
65
+
66
+ linear_classifier = LinearClassifier(embed_dim, num_labels=args.num_labels)
67
+ linear_classifier = linear_classifier.cuda()
68
+ linear_classifier = nn.parallel.DistributedDataParallel(
69
+ linear_classifier, device_ids=[args.gpu]
70
+ )
71
+
72
+ # ============ preparing data ... ============
73
+ val_transform = pth_transforms.Compose(
74
+ [
75
+ pth_transforms.Resize(256, interpolation=3),
76
+ pth_transforms.CenterCrop(224),
77
+ pth_transforms.ToTensor(),
78
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
79
+ ]
80
+ )
81
+ dataset_val = datasets.ImageFolder(
82
+ os.path.join(args.data_path, "val"), transform=val_transform
83
+ )
84
+ val_loader = torch.utils.data.DataLoader(
85
+ dataset_val,
86
+ batch_size=args.batch_size_per_gpu,
87
+ num_workers=args.num_workers,
88
+ pin_memory=True,
89
+ )
90
+
91
+ if args.evaluate:
92
+ utils.load_pretrained_linear_weights(
93
+ linear_classifier, args.arch, args.patch_size
94
+ )
95
+ test_stats = validate_network(
96
+ val_loader,
97
+ model,
98
+ linear_classifier,
99
+ args.n_last_blocks,
100
+ args.avgpool_patchtokens,
101
+ )
102
+ print(
103
+ f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
104
+ )
105
+ return
106
+
107
+ train_transform = pth_transforms.Compose(
108
+ [
109
+ pth_transforms.RandomResizedCrop(224),
110
+ pth_transforms.RandomHorizontalFlip(),
111
+ pth_transforms.ToTensor(),
112
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
113
+ ]
114
+ )
115
+ dataset_train = datasets.ImageFolder(
116
+ os.path.join(args.data_path, "train"), transform=train_transform
117
+ )
118
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
119
+ train_loader = torch.utils.data.DataLoader(
120
+ dataset_train,
121
+ sampler=sampler,
122
+ batch_size=args.batch_size_per_gpu,
123
+ num_workers=args.num_workers,
124
+ pin_memory=True,
125
+ )
126
+ print(
127
+ f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs."
128
+ )
129
+
130
+ # set optimizer
131
+ optimizer = torch.optim.SGD(
132
+ linear_classifier.parameters(),
133
+ args.lr
134
+ * (args.batch_size_per_gpu * utils.get_world_size())
135
+ / 256.0, # linear scaling rule
136
+ momentum=0.9,
137
+ weight_decay=0, # we do not apply weight decay
138
+ )
139
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
140
+ optimizer, args.epochs, eta_min=0
141
+ )
142
+
143
+ # Optionally resume from a checkpoint
144
+ to_restore = {"epoch": 0, "best_acc": 0.0}
145
+ utils.restart_from_checkpoint(
146
+ os.path.join(args.output_dir, "checkpoint.pth.tar"),
147
+ run_variables=to_restore,
148
+ state_dict=linear_classifier,
149
+ optimizer=optimizer,
150
+ scheduler=scheduler,
151
+ )
152
+ start_epoch = to_restore["epoch"]
153
+ best_acc = to_restore["best_acc"]
154
+
155
+ for epoch in range(start_epoch, args.epochs):
156
+ train_loader.sampler.set_epoch(epoch)
157
+
158
+ train_stats = train(
159
+ model,
160
+ linear_classifier,
161
+ optimizer,
162
+ train_loader,
163
+ epoch,
164
+ args.n_last_blocks,
165
+ args.avgpool_patchtokens,
166
+ )
167
+ scheduler.step()
168
+
169
+ log_stats = {
170
+ **{f"train_{k}": v for k, v in train_stats.items()},
171
+ "epoch": epoch,
172
+ }
173
+ if epoch % args.val_freq == 0 or epoch == args.epochs - 1:
174
+ test_stats = validate_network(
175
+ val_loader,
176
+ model,
177
+ linear_classifier,
178
+ args.n_last_blocks,
179
+ args.avgpool_patchtokens,
180
+ )
181
+ print(
182
+ f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
183
+ )
184
+ best_acc = max(best_acc, test_stats["acc1"])
185
+ print(f"Max accuracy so far: {best_acc:.2f}%")
186
+ log_stats = {
187
+ **{k: v for k, v in log_stats.items()},
188
+ **{f"test_{k}": v for k, v in test_stats.items()},
189
+ }
190
+ if utils.is_main_process():
191
+ with (Path(args.output_dir) / "log.txt").open("a") as f:
192
+ f.write(json.dumps(log_stats) + "\n")
193
+ save_dict = {
194
+ "epoch": epoch + 1,
195
+ "state_dict": linear_classifier.state_dict(),
196
+ "optimizer": optimizer.state_dict(),
197
+ "scheduler": scheduler.state_dict(),
198
+ "best_acc": best_acc,
199
+ }
200
+ torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar"))
201
+ print(
202
+ "Training of the supervised linear classifier on frozen features completed.\n"
203
+ "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc)
204
+ )
205
+
206
+
207
+ def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool):
208
+ linear_classifier.train()
209
+ metric_logger = utils.MetricLogger(delimiter=" ")
210
+ metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
211
+ header = "Epoch: [{}]".format(epoch)
212
+ for (inp, target) in metric_logger.log_every(loader, 20, header):
213
+ # move to gpu
214
+ inp = inp.cuda(non_blocking=True)
215
+ target = target.cuda(non_blocking=True)
216
+
217
+ # forward
218
+ with torch.no_grad():
219
+ if "vit" in args.arch:
220
+ intermediate_output = model.get_intermediate_layers(inp, n)
221
+ output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
222
+ if avgpool:
223
+ output = torch.cat(
224
+ (
225
+ output.unsqueeze(-1),
226
+ torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(
227
+ -1
228
+ ),
229
+ ),
230
+ dim=-1,
231
+ )
232
+ output = output.reshape(output.shape[0], -1)
233
+ else:
234
+ output = model(inp)
235
+ output = linear_classifier(output)
236
+
237
+ # compute cross entropy loss
238
+ loss = nn.CrossEntropyLoss()(output, target)
239
+
240
+ # compute the gradients
241
+ optimizer.zero_grad()
242
+ loss.backward()
243
+
244
+ # step
245
+ optimizer.step()
246
+
247
+ # log
248
+ torch.cuda.synchronize()
249
+ metric_logger.update(loss=loss.item())
250
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
251
+ # gather the stats from all processes
252
+ metric_logger.synchronize_between_processes()
253
+ print("Averaged stats:", metric_logger)
254
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
255
+
256
+
257
+ @torch.no_grad()
258
+ def validate_network(val_loader, model, linear_classifier, n, avgpool):
259
+ linear_classifier.eval()
260
+ metric_logger = utils.MetricLogger(delimiter=" ")
261
+ header = "Test:"
262
+ for inp, target in metric_logger.log_every(val_loader, 20, header):
263
+ # move to gpu
264
+ inp = inp.cuda(non_blocking=True)
265
+ target = target.cuda(non_blocking=True)
266
+
267
+ # forward
268
+ with torch.no_grad():
269
+ if "vit" in args.arch:
270
+ intermediate_output = model.get_intermediate_layers(inp, n)
271
+ output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
272
+ if avgpool:
273
+ output = torch.cat(
274
+ (
275
+ output.unsqueeze(-1),
276
+ torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(
277
+ -1
278
+ ),
279
+ ),
280
+ dim=-1,
281
+ )
282
+ output = output.reshape(output.shape[0], -1)
283
+ else:
284
+ output = model(inp)
285
+ output = linear_classifier(output)
286
+ loss = nn.CrossEntropyLoss()(output, target)
287
+
288
+ if linear_classifier.module.num_labels >= 5:
289
+ acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
290
+ else:
291
+ (acc1,) = utils.accuracy(output, target, topk=(1,))
292
+
293
+ batch_size = inp.shape[0]
294
+ metric_logger.update(loss=loss.item())
295
+ metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
296
+ if linear_classifier.module.num_labels >= 5:
297
+ metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
298
+ if linear_classifier.module.num_labels >= 5:
299
+ print(
300
+ "* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}".format(
301
+ top1=metric_logger.acc1,
302
+ top5=metric_logger.acc5,
303
+ losses=metric_logger.loss,
304
+ )
305
+ )
306
+ else:
307
+ print(
308
+ "* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}".format(
309
+ top1=metric_logger.acc1, losses=metric_logger.loss
310
+ )
311
+ )
312
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
313
+
314
+
315
+ class LinearClassifier(nn.Module):
316
+ """Linear layer to train on top of frozen features"""
317
+
318
+ def __init__(self, dim, num_labels=1000):
319
+ super(LinearClassifier, self).__init__()
320
+ self.num_labels = num_labels
321
+ self.linear = nn.Linear(dim, num_labels)
322
+ self.linear.weight.data.normal_(mean=0.0, std=0.01)
323
+ self.linear.bias.data.zero_()
324
+
325
+ def forward(self, x):
326
+ # flatten
327
+ x = x.view(x.size(0), -1)
328
+
329
+ # linear layer
330
+ return self.linear(x)
331
+
332
+
333
+ if __name__ == "__main__":
334
+ parser = argparse.ArgumentParser(
335
+ "Evaluation with linear classification on ImageNet"
336
+ )
337
+ parser.add_argument(
338
+ "--n_last_blocks",
339
+ default=4,
340
+ type=int,
341
+ help="""Concatenate [CLS] tokens
342
+ for the `n` last blocks. We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""",
343
+ )
344
+ parser.add_argument(
345
+ "--avgpool_patchtokens",
346
+ default=False,
347
+ type=utils.bool_flag,
348
+ help="""Whether ot not to concatenate the global average pooled features to the [CLS] token.
349
+ We typically set this to False for ViT-Small and to True with ViT-Base.""",
350
+ )
351
+ parser.add_argument("--arch", default="vit_small", type=str, help="Architecture")
352
+ parser.add_argument(
353
+ "--patch_size", default=16, type=int, help="Patch resolution of the model."
354
+ )
355
+ parser.add_argument(
356
+ "--pretrained_weights",
357
+ default="",
358
+ type=str,
359
+ help="Path to pretrained weights to evaluate.",
360
+ )
361
+ parser.add_argument(
362
+ "--checkpoint_key",
363
+ default="teacher",
364
+ type=str,
365
+ help='Key to use in the checkpoint (example: "teacher")',
366
+ )
367
+ parser.add_argument(
368
+ "--epochs", default=100, type=int, help="Number of epochs of training."
369
+ )
370
+ parser.add_argument(
371
+ "--lr",
372
+ default=0.001,
373
+ type=float,
374
+ help="""Learning rate at the beginning of
375
+ training (highest LR used during training). The learning rate is linearly scaled
376
+ with the batch size, and specified here for a reference batch size of 256.
377
+ We recommend tweaking the LR depending on the checkpoint evaluated.""",
378
+ )
379
+ parser.add_argument(
380
+ "--batch_size_per_gpu", default=128, type=int, help="Per-GPU batch-size"
381
+ )
382
+ parser.add_argument(
383
+ "--dist_url",
384
+ default="env://",
385
+ type=str,
386
+ help="""url used to set up
387
+ distributed training; see https://pytorch.org/docs/stable/distributed.html""",
388
+ )
389
+ parser.add_argument(
390
+ "--local_rank",
391
+ default=0,
392
+ type=int,
393
+ help="Please ignore and do not set this argument.",
394
+ )
395
+ parser.add_argument("--data_path", default="/path/to/imagenet/", type=str)
396
+ parser.add_argument(
397
+ "--num_workers",
398
+ default=10,
399
+ type=int,
400
+ help="Number of data loading workers per GPU.",
401
+ )
402
+ parser.add_argument(
403
+ "--val_freq", default=1, type=int, help="Epoch frequency for validation."
404
+ )
405
+ parser.add_argument(
406
+ "--output_dir", default=".", help="Path to save logs and checkpoints"
407
+ )
408
+ parser.add_argument(
409
+ "--num_labels",
410
+ default=1000,
411
+ type=int,
412
+ help="Number of labels for linear classifier",
413
+ )
414
+ parser.add_argument(
415
+ "--evaluate",
416
+ dest="evaluate",
417
+ action="store_true",
418
+ help="evaluate model on validation set",
419
+ )
420
+ args = parser.parse_args()
421
+ eval_linear(args)
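
As a quick sanity check of the linear probe defined above, the sketch below runs a copy of `LinearClassifier` on random tensors shaped like the concatenated [CLS] tokens of the last `n` ViT-Small blocks (embed dim 384, `n_last_blocks=4` as in the defaults). It is a standalone illustration, not a replacement for the distributed training loop in `eval_linear.py`.

```python
# Standalone sketch: the linear probe on frozen features, with random tensors
# standing in for the concatenated [CLS] tokens of the 4 last ViT-S blocks.
import torch
from torch import nn

n_last_blocks, embed_dim, num_labels, batch = 4, 384, 1000, 16   # ViT-S/16 defaults
feat_dim = embed_dim * n_last_blocks                             # same as embed_dim computed in eval_linear

class LinearClassifier(nn.Module):
    """Linear layer trained on top of frozen features (mirrors eval_linear.py)."""
    def __init__(self, dim, num_labels=1000):
        super().__init__()
        self.linear = nn.Linear(dim, num_labels)
        self.linear.weight.data.normal_(mean=0.0, std=0.01)
        self.linear.bias.data.zero_()

    def forward(self, x):
        return self.linear(x.view(x.size(0), -1))

features = torch.randn(batch, feat_dim)                    # frozen backbone output (placeholder)
logits = LinearClassifier(feat_dim, num_labels)(features)
print(logits.shape)                                        # torch.Size([16, 1000])
```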
dino/eval_video_segmentation.py ADDED
@@ -0,0 +1,372 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Some parts are taken from https://github.com/Liusifei/UVC
16
+ """
17
+ import os
18
+ import copy
19
+ import glob
20
+ import queue
21
+ from urllib.request import urlopen
22
+ import argparse
23
+ import numpy as np
24
+ from tqdm import tqdm
25
+
26
+ import cv2
27
+ import torch
28
+ import torch.nn as nn
29
+ from torch.nn import functional as F
30
+ from PIL import Image
31
+ from torchvision import transforms
32
+
33
+ import utils
34
+ import vision_transformer as vits
35
+
36
+
37
+ @torch.no_grad()
38
+ def eval_video_tracking_davis(
39
+ args, model, frame_list, video_dir, first_seg, seg_ori, color_palette
40
+ ):
41
+ """
42
+ Evaluate tracking on a video given first frame & segmentation
43
+ """
44
+ video_folder = os.path.join(args.output_dir, video_dir.split("/")[-1])
45
+ os.makedirs(video_folder, exist_ok=True)
46
+
47
+ # The queue stores the n preceeding frames
48
+ que = queue.Queue(args.n_last_frames)
49
+
50
+ # first frame
51
+ frame1, ori_h, ori_w = read_frame(frame_list[0])
52
+ # extract first frame feature
53
+ frame1_feat = extract_feature(model, frame1).T # dim x h*w
54
+
55
+ # saving first segmentation
56
+ out_path = os.path.join(video_folder, "00000.png")
57
+ imwrite_indexed(out_path, seg_ori, color_palette)
58
+ mask_neighborhood = None
59
+ for cnt in tqdm(range(1, len(frame_list))):
60
+ frame_tar = read_frame(frame_list[cnt])[0]
61
+
62
+ # we use the first segmentation and the n previous ones
63
+ used_frame_feats = [frame1_feat] + [pair[0] for pair in list(que.queue)]
64
+ used_segs = [first_seg] + [pair[1] for pair in list(que.queue)]
65
+
66
+ frame_tar_avg, feat_tar, mask_neighborhood = label_propagation(
67
+ args, model, frame_tar, used_frame_feats, used_segs, mask_neighborhood
68
+ )
69
+
70
+ # pop out oldest frame if neccessary
71
+ if que.qsize() == args.n_last_frames:
72
+ que.get()
73
+ # push current results into queue
74
+ seg = copy.deepcopy(frame_tar_avg)
75
+ que.put([feat_tar, seg])
76
+
77
+ # upsampling & argmax
78
+ frame_tar_avg = F.interpolate(
79
+ frame_tar_avg,
80
+ scale_factor=args.patch_size,
81
+ mode="bilinear",
82
+ align_corners=False,
83
+ recompute_scale_factor=False,
84
+ )[0]
85
+ frame_tar_avg = norm_mask(frame_tar_avg)
86
+ _, frame_tar_seg = torch.max(frame_tar_avg, dim=0)
87
+
88
+ # saving to disk
89
+ frame_tar_seg = np.array(frame_tar_seg.squeeze().cpu(), dtype=np.uint8)
90
+ frame_tar_seg = np.array(
91
+ Image.fromarray(frame_tar_seg).resize((ori_w, ori_h), 0)
92
+ )
93
+ frame_nm = frame_list[cnt].split("/")[-1].replace(".jpg", ".png")
94
+ imwrite_indexed(
95
+ os.path.join(video_folder, frame_nm), frame_tar_seg, color_palette
96
+ )
97
+
98
+
99
+ def restrict_neighborhood(h, w):
100
+ # We restrict the set of source nodes considered to a spatial neighborhood of the query node (i.e. ``local attention'')
101
+ mask = torch.zeros(h, w, h, w)
102
+ for i in range(h):
103
+ for j in range(w):
104
+ for p in range(2 * args.size_mask_neighborhood + 1):
105
+ for q in range(2 * args.size_mask_neighborhood + 1):
106
+ if (
107
+ i - args.size_mask_neighborhood + p < 0
108
+ or i - args.size_mask_neighborhood + p >= h
109
+ ):
110
+ continue
111
+ if (
112
+ j - args.size_mask_neighborhood + q < 0
113
+ or j - args.size_mask_neighborhood + q >= w
114
+ ):
115
+ continue
116
+ mask[
117
+ i,
118
+ j,
119
+ i - args.size_mask_neighborhood + p,
120
+ j - args.size_mask_neighborhood + q,
121
+ ] = 1
122
+
123
+ mask = mask.reshape(h * w, h * w)
124
+ return mask.cuda(non_blocking=True)
125
+
126
+
127
+ def norm_mask(mask):
128
+ c, h, w = mask.size()
129
+ for cnt in range(c):
130
+ mask_cnt = mask[cnt, :, :]
131
+ if mask_cnt.max() > 0:
132
+ mask_cnt = mask_cnt - mask_cnt.min()
133
+ mask_cnt = mask_cnt / mask_cnt.max()
134
+ mask[cnt, :, :] = mask_cnt
135
+ return mask
136
+
137
+
138
+ def label_propagation(
139
+ args, model, frame_tar, list_frame_feats, list_segs, mask_neighborhood=None
140
+ ):
141
+ """
142
+ propagate segs of frames in list_frames to frame_tar
143
+ """
144
+ ## we only need to extract feature of the target frame
145
+ feat_tar, h, w = extract_feature(model, frame_tar, return_h_w=True)
146
+
147
+ return_feat_tar = feat_tar.T # dim x h*w
148
+
149
+ ncontext = len(list_frame_feats)
150
+ feat_sources = torch.stack(list_frame_feats) # nmb_context x dim x h*w
151
+
152
+ feat_tar = F.normalize(feat_tar, dim=1, p=2)
153
+ feat_sources = F.normalize(feat_sources, dim=1, p=2)
154
+
155
+ feat_tar = feat_tar.unsqueeze(0).repeat(ncontext, 1, 1)
156
+ aff = torch.exp(
157
+ torch.bmm(feat_tar, feat_sources) / 0.1
158
+ ) # nmb_context x h*w (tar: query) x h*w (source: keys)
159
+
160
+ if args.size_mask_neighborhood > 0:
161
+ if mask_neighborhood is None:
162
+ mask_neighborhood = restrict_neighborhood(h, w)
163
+ mask_neighborhood = mask_neighborhood.unsqueeze(0).repeat(ncontext, 1, 1)
164
+ aff *= mask_neighborhood
165
+
166
+ aff = aff.transpose(2, 1).reshape(
167
+ -1, h * w
168
+ ) # nmb_context*h*w (source: keys) x h*w (tar: queries)
169
+ tk_val, _ = torch.topk(aff, dim=0, k=args.topk)
170
+ tk_val_min, _ = torch.min(tk_val, dim=0)
171
+ aff[aff < tk_val_min] = 0
172
+
173
+ aff = aff / torch.sum(aff, keepdim=True, axis=0)
174
+
175
+ list_segs = [s.cuda() for s in list_segs]
176
+ segs = torch.cat(list_segs)
177
+ nmb_context, C, h, w = segs.shape
178
+ segs = (
179
+ segs.reshape(nmb_context, C, -1).transpose(2, 1).reshape(-1, C).T
180
+ ) # C x nmb_context*h*w
181
+ seg_tar = torch.mm(segs, aff)
182
+ seg_tar = seg_tar.reshape(1, C, h, w)
183
+ return seg_tar, return_feat_tar, mask_neighborhood
184
+
185
+
186
+ def extract_feature(model, frame, return_h_w=False):
187
+ """Extract one frame feature everytime."""
188
+ out = model.get_intermediate_layers(frame.unsqueeze(0).cuda(), n=1)[0]
189
+ out = out[:, 1:, :] # we discard the [CLS] token
190
+ h, w = int(frame.shape[1] / model.patch_embed.patch_size), int(
191
+ frame.shape[2] / model.patch_embed.patch_size
192
+ )
193
+ dim = out.shape[-1]
194
+ out = out[0].reshape(h, w, dim)
195
+ out = out.reshape(-1, dim)
196
+ if return_h_w:
197
+ return out, h, w
198
+ return out
199
+
200
+
201
+ def imwrite_indexed(filename, array, color_palette):
202
+ """Save indexed png for DAVIS."""
203
+ if np.atleast_3d(array).shape[2] != 1:
204
+ raise Exception("Saving indexed PNGs requires 2D array.")
205
+
206
+ im = Image.fromarray(array)
207
+ im.putpalette(color_palette.ravel())
208
+ im.save(filename, format="PNG")
209
+
210
+
211
+ def to_one_hot(y_tensor, n_dims=None):
212
+ """
213
+ Take integer y (tensor or variable) with n dims &
214
+ convert it to 1-hot representation with n+1 dims.
215
+ """
216
+ if n_dims is None:
217
+ n_dims = int(y_tensor.max() + 1)
218
+ _, h, w = y_tensor.size()
219
+ y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1)
220
+ n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1
221
+ y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)
222
+ y_one_hot = y_one_hot.view(h, w, n_dims)
223
+ return y_one_hot.permute(2, 0, 1).unsqueeze(0)
224
+
225
+
226
+ def read_frame_list(video_dir):
227
+ frame_list = [img for img in glob.glob(os.path.join(video_dir, "*.jpg"))]
228
+ frame_list = sorted(frame_list)
229
+ return frame_list
230
+
231
+
232
+ def read_frame(frame_dir, scale_size=[480]):
233
+ """
234
+ read a single frame & preprocess
235
+ """
236
+ img = cv2.imread(frame_dir)
237
+ ori_h, ori_w, _ = img.shape
238
+ if len(scale_size) == 1:
239
+ if ori_h > ori_w:
240
+ tw = scale_size[0]
241
+ th = (tw * ori_h) / ori_w
242
+ th = int((th // 64) * 64)
243
+ else:
244
+ th = scale_size[0]
245
+ tw = (th * ori_w) / ori_h
246
+ tw = int((tw // 64) * 64)
247
+ else:
248
+ th, tw = scale_size
249
+ img = cv2.resize(img, (tw, th))
250
+ img = img.astype(np.float32)
251
+ img = img / 255.0
252
+ img = img[:, :, ::-1]
253
+ img = np.transpose(img.copy(), (2, 0, 1))
254
+ img = torch.from_numpy(img).float()
255
+ img = color_normalize(img)
256
+ return img, ori_h, ori_w
257
+
258
+
259
+ def read_seg(seg_dir, factor, scale_size=[480]):
260
+ seg = Image.open(seg_dir)
261
+ _w, _h = seg.size # note PIL.Image.Image's size is (w, h)
262
+ if len(scale_size) == 1:
263
+ if _w > _h:
264
+ _th = scale_size[0]
265
+ _tw = (_th * _w) / _h
266
+ _tw = int((_tw // 64) * 64)
267
+ else:
268
+ _tw = scale_size[0]
269
+ _th = (_tw * _h) / _w
270
+ _th = int((_th // 64) * 64)
271
+ else:
272
+ _th = scale_size[1]
273
+ _tw = scale_size[0]
274
+ small_seg = np.array(seg.resize((_tw // factor, _th // factor), 0))
275
+ small_seg = torch.from_numpy(small_seg.copy()).contiguous().float().unsqueeze(0)
276
+ return to_one_hot(small_seg), np.asarray(seg)
277
+
278
+
279
+ def color_normalize(x, mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]):
280
+ for t, m, s in zip(x, mean, std):
281
+ t.sub_(m)
282
+ t.div_(s)
283
+ return x
284
+
285
+
286
+ if __name__ == "__main__":
287
+ parser = argparse.ArgumentParser(
288
+ "Evaluation with video object segmentation on DAVIS 2017"
289
+ )
290
+ parser.add_argument(
291
+ "--pretrained_weights",
292
+ default="",
293
+ type=str,
294
+ help="Path to pretrained weights to evaluate.",
295
+ )
296
+ parser.add_argument(
297
+ "--arch",
298
+ default="vit_small",
299
+ type=str,
300
+ choices=["vit_tiny", "vit_small", "vit_base"],
301
+ help="Architecture (support only ViT atm).",
302
+ )
303
+ parser.add_argument(
304
+ "--patch_size", default=16, type=int, help="Patch resolution of the model."
305
+ )
306
+ parser.add_argument(
307
+ "--checkpoint_key",
308
+ default="teacher",
309
+ type=str,
310
+ help='Key to use in the checkpoint (example: "teacher")',
311
+ )
312
+ parser.add_argument(
313
+ "--output_dir", default=".", help="Path where to save segmentations"
314
+ )
315
+ parser.add_argument("--data_path", default="/path/to/davis/", type=str)
316
+ parser.add_argument(
317
+ "--n_last_frames", type=int, default=7, help="number of preceeding frames"
318
+ )
319
+ parser.add_argument(
320
+ "--size_mask_neighborhood",
321
+ default=12,
322
+ type=int,
323
+ help="We restrict the set of source nodes considered to a spatial neighborhood of the query node",
324
+ )
325
+ parser.add_argument(
326
+ "--topk", type=int, default=5, help="accumulate label from top k neighbors"
327
+ )
328
+ parser.add_argument(
329
+ "--bs", type=int, default=6, help="Batch size, try to reduce if OOM"
330
+ )
331
+ args = parser.parse_args()
332
+
333
+ print("git:\n {}\n".format(utils.get_sha()))
334
+ print(
335
+ "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
336
+ )
337
+
338
+ # building network
339
+ model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
340
+ print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
341
+ model.cuda()
342
+ utils.load_pretrained_weights(
343
+ model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size
344
+ )
345
+ for param in model.parameters():
346
+ param.requires_grad = False
347
+ model.eval()
348
+
349
+ color_palette = []
350
+ for line in urlopen(
351
+ "https://raw.githubusercontent.com/Liusifei/UVC/master/libs/data/palette.txt"
352
+ ):
353
+ color_palette.append(
354
+ [int(i) for i in line.decode("utf-8").split("\n")[0].split(" ")]
355
+ )
356
+ color_palette = np.asarray(color_palette, dtype=np.uint8).reshape(-1, 3)
357
+
358
+ video_list = open(
359
+ os.path.join(args.data_path, "ImageSets/2017/val.txt")
360
+ ).readlines()
361
+ for i, video_name in enumerate(video_list):
362
+ video_name = video_name.strip()
363
+ print(f"[{i}/{len(video_list)}] Begin to segmentate video {video_name}.")
364
+ video_dir = os.path.join(args.data_path, "JPEGImages/480p/", video_name)
365
+ frame_list = read_frame_list(video_dir)
366
+ seg_path = (
367
+ frame_list[0].replace("JPEGImages", "Annotations").replace("jpg", "png")
368
+ )
369
+ first_seg, seg_ori = read_seg(seg_path, args.patch_size)
370
+ eval_video_tracking_davis(
371
+ args, model, frame_list, video_dir, first_seg, seg_ori, color_palette
372
+ )
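
The core of the tracker above is `label_propagation`: labels of the reference frames are carried to the target frame through a temperature-scaled, top-k-sparsified cosine affinity. The snippet below isolates that step for a single context frame using random tensors; the 4x4 feature map and 3 mask channels are arbitrary.

```python
# Minimal sketch of the label propagation step for a single context frame.
# Shapes are arbitrary: a 4x4 feature map with 3 mask channels.
import torch
import torch.nn.functional as F

dim, h, w, C, topk = 64, 4, 4, 3, 5
feat_tar = F.normalize(torch.randn(h * w, dim), dim=1, p=2)   # target frame features
feat_src = F.normalize(torch.randn(dim, h * w), dim=0, p=2)   # one source frame, dim x h*w
seg_src = torch.rand(1, C, h, w)                              # its (soft) segmentation

aff = torch.exp(feat_tar @ feat_src / 0.1)                    # h*w queries x h*w keys
tk_val, _ = torch.topk(aff, k=topk, dim=1)
aff[aff < tk_val.min(dim=1, keepdim=True).values] = 0         # keep only the top-k keys per query
aff = aff / aff.sum(dim=1, keepdim=True)                      # normalize the votes

seg_flat = seg_src.reshape(C, h * w)                          # C x h*w (keys)
seg_tar = (seg_flat @ aff.t()).reshape(1, C, h, w)            # propagated segmentation
print(seg_tar.shape)                                          # torch.Size([1, 3, 4, 4])
```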
dino/hubconf.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ from torchvision.models.resnet import resnet50
16
+
17
+ import vision_transformer as vits
18
+
19
+ dependencies = ["torch", "torchvision"]
20
+
21
+
22
+ def dino_vits16(pretrained=True, **kwargs):
23
+ """
24
+ ViT-Small/16x16 pre-trained with DINO.
25
+ Achieves 74.5% top-1 accuracy on ImageNet with k-NN classification.
26
+ """
27
+ model = vits.__dict__["vit_small"](patch_size=16, num_classes=0, **kwargs)
28
+ if pretrained:
29
+ state_dict = torch.hub.load_state_dict_from_url(
30
+ url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
31
+ map_location="cpu",
32
+ )
33
+ model.load_state_dict(state_dict, strict=True)
34
+ return model
35
+
36
+
37
+ def dino_vits8(pretrained=True, **kwargs):
38
+ """
39
+ ViT-Small/8x8 pre-trained with DINO.
40
+ Achieves 78.3% top-1 accuracy on ImageNet with k-NN classification.
41
+ """
42
+ model = vits.__dict__["vit_small"](patch_size=8, num_classes=0, **kwargs)
43
+ if pretrained:
44
+ state_dict = torch.hub.load_state_dict_from_url(
45
+ url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth",
46
+ map_location="cpu",
47
+ )
48
+ model.load_state_dict(state_dict, strict=True)
49
+ return model
50
+
51
+
52
+ def dino_vitb16(pretrained=True, **kwargs):
53
+ """
54
+ ViT-Base/16x16 pre-trained with DINO.
55
+ Achieves 76.1% top-1 accuracy on ImageNet with k-NN classification.
56
+ """
57
+ model = vits.__dict__["vit_base"](patch_size=16, num_classes=0, **kwargs)
58
+ if pretrained:
59
+ state_dict = torch.hub.load_state_dict_from_url(
60
+ url="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
61
+ map_location="cpu",
62
+ )
63
+ model.load_state_dict(state_dict, strict=True)
64
+ return model
65
+
66
+
67
+ def dino_vitb8(pretrained=True, **kwargs):
68
+ """
69
+ ViT-Base/8x8 pre-trained with DINO.
70
+ Achieves 77.4% top-1 accuracy on ImageNet with k-NN classification.
71
+ """
72
+ model = vits.__dict__["vit_base"](patch_size=8, num_classes=0, **kwargs)
73
+ if pretrained:
74
+ state_dict = torch.hub.load_state_dict_from_url(
75
+ url="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
76
+ map_location="cpu",
77
+ )
78
+ model.load_state_dict(state_dict, strict=True)
79
+ return model
80
+
81
+
82
+ def dino_resnet50(pretrained=True, **kwargs):
83
+ """
84
+ ResNet-50 pre-trained with DINO.
85
+ Achieves 75.3% top-1 accuracy on ImageNet linear evaluation benchmark (requires to train `fc`).
86
+ """
87
+ model = resnet50(pretrained=False, **kwargs)
88
+ model.fc = torch.nn.Identity()
89
+ if pretrained:
90
+ state_dict = torch.hub.load_state_dict_from_url(
91
+ url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
92
+ map_location="cpu",
93
+ )
94
+ model.load_state_dict(state_dict, strict=False)
95
+ return model
96
+
97
+
98
+ def dino_xcit_small_12_p16(pretrained=True, **kwargs):
99
+ """
100
+ XCiT-Small-12/16 pre-trained with DINO.
101
+ """
102
+ model = torch.hub.load(
103
+ "facebookresearch/xcit:main", "xcit_small_12_p16", num_classes=0, **kwargs
104
+ )
105
+ if pretrained:
106
+ state_dict = torch.hub.load_state_dict_from_url(
107
+ url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth",
108
+ map_location="cpu",
109
+ )
110
+ model.load_state_dict(state_dict, strict=True)
111
+ return model
112
+
113
+
114
+ def dino_xcit_small_12_p8(pretrained=True, **kwargs):
115
+ """
116
+ XCiT-Small-12/8 pre-trained with DINO.
117
+ """
118
+ model = torch.hub.load(
119
+ "facebookresearch/xcit:main", "xcit_small_12_p8", num_classes=0, **kwargs
120
+ )
121
+ if pretrained:
122
+ state_dict = torch.hub.load_state_dict_from_url(
123
+ url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth",
124
+ map_location="cpu",
125
+ )
126
+ model.load_state_dict(state_dict, strict=True)
127
+ return model
128
+
129
+
130
+ def dino_xcit_medium_24_p16(pretrained=True, **kwargs):
131
+ """
132
+ XCiT-Medium-24/16 pre-trained with DINO.
133
+ """
134
+ model = torch.hub.load(
135
+ "facebookresearch/xcit:main", "xcit_medium_24_p16", num_classes=0, **kwargs
136
+ )
137
+ if pretrained:
138
+ state_dict = torch.hub.load_state_dict_from_url(
139
+ url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth",
140
+ map_location="cpu",
141
+ )
142
+ model.load_state_dict(state_dict, strict=True)
143
+ return model
144
+
145
+
146
+ def dino_xcit_medium_24_p8(pretrained=True, **kwargs):
147
+ """
148
+ XCiT-Medium-24/8 pre-trained with DINO.
149
+ """
150
+ model = torch.hub.load(
151
+ "facebookresearch/xcit:main", "xcit_medium_24_p8", num_classes=0, **kwargs
152
+ )
153
+ if pretrained:
154
+ state_dict = torch.hub.load_state_dict_from_url(
155
+ url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth",
156
+ map_location="cpu",
157
+ )
158
+ model.load_state_dict(state_dict, strict=True)
159
+ return model
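
The entrypoints above are meant to be consumed through `torch.hub`. Below is a minimal usage sketch, assuming the hubconf is loaded from the upstream `facebookresearch/dino` repository (internet access is needed to fetch the code and weights).

```python
# Load a DINO backbone through torch.hub and embed a dummy image.
# Assumes access to the upstream facebookresearch/dino repo and its weights.
import torch

model = torch.hub.load("facebookresearch/dino:main", "dino_vits16", pretrained=True)
model.eval()

with torch.no_grad():
    emb = model(torch.randn(1, 3, 224, 224))   # [CLS] embedding of a dummy image
print(emb.shape)                               # torch.Size([1, 384]) for ViT-S
```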
dino/main_dino.py ADDED
@@ -0,0 +1,689 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import os
16
+ import sys
17
+ import datetime
18
+ import time
19
+ import math
20
+ import json
21
+ from pathlib import Path
22
+
23
+ import numpy as np
24
+ from PIL import Image
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.distributed as dist
28
+ import torch.backends.cudnn as cudnn
29
+ import torch.nn.functional as F
30
+ from torchvision import datasets, transforms
31
+ from torchvision import models as torchvision_models
32
+
33
+ import utils
34
+ import vision_transformer as vits
35
+ from vision_transformer import DINOHead
36
+
37
+ torchvision_archs = sorted(
38
+ name
39
+ for name in torchvision_models.__dict__
40
+ if name.islower()
41
+ and not name.startswith("__")
42
+ and callable(torchvision_models.__dict__[name])
43
+ )
44
+
45
+
46
+ def get_args_parser():
47
+ parser = argparse.ArgumentParser("DINO", add_help=False)
48
+
49
+ # Model parameters
50
+ parser.add_argument(
51
+ "--arch",
52
+ default="vit_small",
53
+ type=str,
54
+ choices=["vit_tiny", "vit_small", "vit_base", "xcit", "deit_tiny", "deit_small"]
55
+ + torchvision_archs
56
+ + torch.hub.list("facebookresearch/xcit:main"),
57
+ help="""Name of architecture to train. For quick experiments with ViTs,
58
+ we recommend using vit_tiny or vit_small.""",
59
+ )
60
+ parser.add_argument(
61
+ "--patch_size",
62
+ default=16,
63
+ type=int,
64
+ help="""Size in pixels
65
+ of input square patches - default 16 (for 16x16 patches). Using smaller
66
+ values leads to better performance but requires more memory. Applies only
67
+ for ViTs (vit_tiny, vit_small and vit_base). If <16, we recommend disabling
68
+ mixed precision training (--use_fp16 false) to avoid unstabilities.""",
69
+ )
70
+ parser.add_argument(
71
+ "--out_dim",
72
+ default=65536,
73
+ type=int,
74
+ help="""Dimensionality of
75
+ the DINO head output. For complex and large datasets large values (like 65k) work well.""",
76
+ )
77
+ parser.add_argument(
78
+ "--norm_last_layer",
79
+ default=True,
80
+ type=utils.bool_flag,
81
+ help="""Whether or not to weight normalize the last layer of the DINO head.
82
+ Not normalizing leads to better performance but can make the training unstable.
83
+ In our experiments, we typically set this paramater to False with vit_small and True with vit_base.""",
84
+ )
85
+ parser.add_argument(
86
+ "--momentum_teacher",
87
+ default=0.996,
88
+ type=float,
89
+ help="""Base EMA
90
+ parameter for teacher update. The value is increased to 1 during training with cosine schedule.
91
+ We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256.""",
92
+ )
93
+ parser.add_argument(
94
+ "--use_bn_in_head",
95
+ default=False,
96
+ type=utils.bool_flag,
97
+ help="Whether to use batch normalizations in projection head (Default: False)",
98
+ )
99
+
100
+ # Temperature teacher parameters
101
+ parser.add_argument(
102
+ "--warmup_teacher_temp",
103
+ default=0.04,
104
+ type=float,
105
+ help="""Initial value for the teacher temperature: 0.04 works well in most cases.
106
+ Try decreasing it if the training loss does not decrease.""",
107
+ )
108
+ parser.add_argument(
109
+ "--teacher_temp",
110
+ default=0.04,
111
+ type=float,
112
+ help="""Final value (after linear warmup)
113
+ of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend
114
+ starting with the default value of 0.04 and increase this slightly if needed.""",
115
+ )
116
+ parser.add_argument(
117
+ "--warmup_teacher_temp_epochs",
118
+ default=0,
119
+ type=int,
120
+ help="Number of warmup epochs for the teacher temperature (Default: 30).",
121
+ )
122
+
123
+ # Training/Optimization parameters
124
+ parser.add_argument(
125
+ "--use_fp16",
126
+ type=utils.bool_flag,
127
+ default=True,
128
+ help="""Whether or not
129
+ to use half precision for training. Improves training time and memory requirements,
130
+ but can provoke instability and slight decay of performance. We recommend disabling
131
+ mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""",
132
+ )
133
+ parser.add_argument(
134
+ "--weight_decay",
135
+ type=float,
136
+ default=0.04,
137
+ help="""Initial value of the
138
+ weight decay. With ViT, a smaller value at the beginning of training works well.""",
139
+ )
140
+ parser.add_argument(
141
+ "--weight_decay_end",
142
+ type=float,
143
+ default=0.4,
144
+ help="""Final value of the
145
+ weight decay. We use a cosine schedule for WD and using a larger decay by
146
+ the end of training improves performance for ViTs.""",
147
+ )
148
+ parser.add_argument(
149
+ "--clip_grad",
150
+ type=float,
151
+ default=3.0,
152
+ help="""Maximal parameter
153
+ gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
154
+ help optimization for larger ViT architectures. 0 for disabling.""",
155
+ )
156
+ parser.add_argument(
157
+ "--batch_size_per_gpu",
158
+ default=64,
159
+ type=int,
160
+ help="Per-GPU batch-size : number of distinct images loaded on one GPU.",
161
+ )
162
+ parser.add_argument(
163
+ "--epochs", default=100, type=int, help="Number of epochs of training."
164
+ )
165
+ parser.add_argument(
166
+ "--freeze_last_layer",
167
+ default=1,
168
+ type=int,
169
+ help="""Number of epochs
170
+ during which we keep the output layer fixed. Typically doing so during
171
+ the first epoch helps training. Try increasing this value if the loss does not decrease.""",
172
+ )
173
+ parser.add_argument(
174
+ "--lr",
175
+ default=0.0005,
176
+ type=float,
177
+ help="""Learning rate at the end of
178
+ linear warmup (highest LR used during training). The learning rate is linearly scaled
179
+ with the batch size, and specified here for a reference batch size of 256.""",
180
+ )
181
+ parser.add_argument(
182
+ "--warmup_epochs",
183
+ default=10,
184
+ type=int,
185
+ help="Number of epochs for the linear learning-rate warm up.",
186
+ )
187
+ parser.add_argument(
188
+ "--min_lr",
189
+ type=float,
190
+ default=1e-6,
191
+ help="""Target LR at the
192
+ end of optimization. We use a cosine LR schedule with linear warmup.""",
193
+ )
194
+ parser.add_argument(
195
+ "--optimizer",
196
+ default="adamw",
197
+ type=str,
198
+ choices=["adamw", "sgd", "lars"],
199
+ help="""Type of optimizer. We recommend using adamw with ViTs.""",
200
+ )
201
+ parser.add_argument(
202
+ "--drop_path_rate", type=float, default=0.1, help="stochastic depth rate"
203
+ )
204
+
205
+ # Multi-crop parameters
206
+ parser.add_argument(
207
+ "--global_crops_scale",
208
+ type=float,
209
+ nargs="+",
210
+ default=(0.4, 1.0),
211
+ help="""Scale range of the cropped image before resizing, relatively to the origin image.
212
+ Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
213
+ recommand using a wider range of scale ("--global_crops_scale 0.14 1." for example)""",
214
+ )
215
+ parser.add_argument(
216
+ "--local_crops_number",
217
+ type=int,
218
+ default=8,
219
+ help="""Number of small
220
+ local views to generate. Set this parameter to 0 to disable multi-crop training.
221
+ When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """,
222
+ )
223
+ parser.add_argument(
224
+ "--local_crops_scale",
225
+ type=float,
226
+ nargs="+",
227
+ default=(0.05, 0.4),
228
+ help="""Scale range of the cropped image before resizing, relatively to the origin image.
229
+ Used for small local view cropping of multi-crop.""",
230
+ )
231
+
232
+ # Misc
233
+ parser.add_argument(
234
+ "--data_path",
235
+ default="/path/to/imagenet/train/",
236
+ type=str,
237
+ help="Please specify path to the ImageNet training data.",
238
+ )
239
+ parser.add_argument(
240
+ "--output_dir", default=".", type=str, help="Path to save logs and checkpoints."
241
+ )
242
+ parser.add_argument(
243
+ "--saveckp_freq", default=20, type=int, help="Save checkpoint every x epochs."
244
+ )
245
+ parser.add_argument("--seed", default=0, type=int, help="Random seed.")
246
+ parser.add_argument(
247
+ "--num_workers",
248
+ default=10,
249
+ type=int,
250
+ help="Number of data loading workers per GPU.",
251
+ )
252
+ parser.add_argument(
253
+ "--dist_url",
254
+ default="env://",
255
+ type=str,
256
+ help="""url used to set up
257
+ distributed training; see https://pytorch.org/docs/stable/distributed.html""",
258
+ )
259
+ parser.add_argument(
260
+ "--local_rank",
261
+ default=0,
262
+ type=int,
263
+ help="Please ignore and do not set this argument.",
264
+ )
265
+ return parser
266
+
267
+
268
+ def train_dino(args):
269
+ utils.init_distributed_mode(args)
270
+ utils.fix_random_seeds(args.seed)
271
+ print("git:\n {}\n".format(utils.get_sha()))
272
+ print(
273
+ "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
274
+ )
275
+ cudnn.benchmark = True
276
+
277
+ # ============ preparing data ... ============
278
+ transform = DataAugmentationDINO(
279
+ args.global_crops_scale,
280
+ args.local_crops_scale,
281
+ args.local_crops_number,
282
+ )
283
+ dataset = datasets.ImageFolder(args.data_path, transform=transform)
284
+ sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
285
+ data_loader = torch.utils.data.DataLoader(
286
+ dataset,
287
+ sampler=sampler,
288
+ batch_size=args.batch_size_per_gpu,
289
+ num_workers=args.num_workers,
290
+ pin_memory=True,
291
+ drop_last=True,
292
+ )
293
+ print(f"Data loaded: there are {len(dataset)} images.")
294
+
295
+ # ============ building student and teacher networks ... ============
296
+ # we changed the name DeiT-S for ViT-S to avoid confusions
297
+ args.arch = args.arch.replace("deit", "vit")
298
+ # if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
299
+ if args.arch in vits.__dict__.keys():
300
+ student = vits.__dict__[args.arch](
301
+ patch_size=args.patch_size,
302
+ drop_path_rate=args.drop_path_rate, # stochastic depth
303
+ )
304
+ teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
305
+ embed_dim = student.embed_dim
306
+ # if the network is a XCiT
307
+ elif args.arch in torch.hub.list("facebookresearch/xcit:main"):
308
+ student = torch.hub.load(
309
+ "facebookresearch/xcit:main",
310
+ args.arch,
311
+ pretrained=False,
312
+ drop_path_rate=args.drop_path_rate,
313
+ )
314
+ teacher = torch.hub.load(
315
+ "facebookresearch/xcit:main", args.arch, pretrained=False
316
+ )
317
+ embed_dim = student.embed_dim
318
+ # otherwise, we check if the architecture is in torchvision models
319
+ elif args.arch in torchvision_models.__dict__.keys():
320
+ student = torchvision_models.__dict__[args.arch]()
321
+ teacher = torchvision_models.__dict__[args.arch]()
322
+ embed_dim = student.fc.weight.shape[1]
323
+ else:
324
+ print(f"Unknow architecture: {args.arch}")
325
+
326
+ # multi-crop wrapper handles forward with inputs of different resolutions
327
+ student = utils.MultiCropWrapper(
328
+ student,
329
+ DINOHead(
330
+ embed_dim,
331
+ args.out_dim,
332
+ use_bn=args.use_bn_in_head,
333
+ norm_last_layer=args.norm_last_layer,
334
+ ),
335
+ )
336
+ teacher = utils.MultiCropWrapper(
337
+ teacher,
338
+ DINOHead(embed_dim, args.out_dim, args.use_bn_in_head),
339
+ )
340
+ # move networks to gpu
341
+ student, teacher = student.cuda(), teacher.cuda()
342
+ # synchronize batch norms (if any)
343
+ if utils.has_batchnorms(student):
344
+ student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
345
+ teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)
346
+
347
+ # we need DDP wrapper to have synchro batch norms working...
348
+ teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu])
349
+ teacher_without_ddp = teacher.module
350
+ else:
351
+ # teacher_without_ddp and teacher are the same thing
352
+ teacher_without_ddp = teacher
353
+ student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
354
+ # teacher and student start with the same weights
355
+ teacher_without_ddp.load_state_dict(student.module.state_dict())
356
+ # there is no backpropagation through the teacher, so no need for gradients
357
+ for p in teacher.parameters():
358
+ p.requires_grad = False
359
+ print(f"Student and Teacher are built: they are both {args.arch} network.")
360
+
361
+ # ============ preparing loss ... ============
362
+ dino_loss = DINOLoss(
363
+ args.out_dim,
364
+ args.local_crops_number
365
+ + 2, # total number of crops = 2 global crops + local_crops_number
366
+ args.warmup_teacher_temp,
367
+ args.teacher_temp,
368
+ args.warmup_teacher_temp_epochs,
369
+ args.epochs,
370
+ ).cuda()
371
+
372
+ # ============ preparing optimizer ... ============
373
+ params_groups = utils.get_params_groups(student)
374
+ if args.optimizer == "adamw":
375
+ optimizer = torch.optim.AdamW(params_groups) # to use with ViTs
376
+ elif args.optimizer == "sgd":
377
+ optimizer = torch.optim.SGD(
378
+ params_groups, lr=0, momentum=0.9
379
+ ) # lr is set by scheduler
380
+ elif args.optimizer == "lars":
381
+ optimizer = utils.LARS(params_groups) # to use with convnet and large batches
382
+ # for mixed precision training
383
+ fp16_scaler = None
384
+ if args.use_fp16:
385
+ fp16_scaler = torch.cuda.amp.GradScaler()
386
+
387
+ # ============ init schedulers ... ============
388
+ lr_schedule = utils.cosine_scheduler(
389
+ args.lr
390
+ * (args.batch_size_per_gpu * utils.get_world_size())
391
+ / 256.0, # linear scaling rule
392
+ args.min_lr,
393
+ args.epochs,
394
+ len(data_loader),
395
+ warmup_epochs=args.warmup_epochs,
396
+ )
397
+ wd_schedule = utils.cosine_scheduler(
398
+ args.weight_decay,
399
+ args.weight_decay_end,
400
+ args.epochs,
401
+ len(data_loader),
402
+ )
403
+ # momentum parameter is increased to 1. during training with a cosine schedule
404
+ momentum_schedule = utils.cosine_scheduler(
405
+ args.momentum_teacher, 1, args.epochs, len(data_loader)
406
+ )
407
+ print(f"Loss, optimizer and schedulers ready.")
408
+
409
+ # ============ optionally resume training ... ============
410
+ to_restore = {"epoch": 0}
411
+ utils.restart_from_checkpoint(
412
+ os.path.join(args.output_dir, "checkpoint.pth"),
413
+ run_variables=to_restore,
414
+ student=student,
415
+ teacher=teacher,
416
+ optimizer=optimizer,
417
+ fp16_scaler=fp16_scaler,
418
+ dino_loss=dino_loss,
419
+ )
420
+ start_epoch = to_restore["epoch"]
421
+
422
+ start_time = time.time()
423
+ print("Starting DINO training !")
424
+ for epoch in range(start_epoch, args.epochs):
425
+ data_loader.sampler.set_epoch(epoch)
426
+
427
+ # ============ training one epoch of DINO ... ============
428
+ train_stats = train_one_epoch(
429
+ student,
430
+ teacher,
431
+ teacher_without_ddp,
432
+ dino_loss,
433
+ data_loader,
434
+ optimizer,
435
+ lr_schedule,
436
+ wd_schedule,
437
+ momentum_schedule,
438
+ epoch,
439
+ fp16_scaler,
440
+ args,
441
+ )
442
+
443
+ # ============ writing logs ... ============
444
+ save_dict = {
445
+ "student": student.state_dict(),
446
+ "teacher": teacher.state_dict(),
447
+ "optimizer": optimizer.state_dict(),
448
+ "epoch": epoch + 1,
449
+ "args": args,
450
+ "dino_loss": dino_loss.state_dict(),
451
+ }
452
+ if fp16_scaler is not None:
453
+ save_dict["fp16_scaler"] = fp16_scaler.state_dict()
454
+ utils.save_on_master(save_dict, os.path.join(args.output_dir, "checkpoint.pth"))
455
+ if args.saveckp_freq and epoch % args.saveckp_freq == 0:
456
+ utils.save_on_master(
457
+ save_dict, os.path.join(args.output_dir, f"checkpoint{epoch:04}.pth")
458
+ )
459
+ log_stats = {
460
+ **{f"train_{k}": v for k, v in train_stats.items()},
461
+ "epoch": epoch,
462
+ }
463
+ if utils.is_main_process():
464
+ with (Path(args.output_dir) / "log.txt").open("a") as f:
465
+ f.write(json.dumps(log_stats) + "\n")
466
+ total_time = time.time() - start_time
467
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
468
+ print("Training time {}".format(total_time_str))
469
+
470
+
471
+ def train_one_epoch(
472
+ student,
473
+ teacher,
474
+ teacher_without_ddp,
475
+ dino_loss,
476
+ data_loader,
477
+ optimizer,
478
+ lr_schedule,
479
+ wd_schedule,
480
+ momentum_schedule,
481
+ epoch,
482
+ fp16_scaler,
483
+ args,
484
+ ):
485
+ metric_logger = utils.MetricLogger(delimiter=" ")
486
+ header = "Epoch: [{}/{}]".format(epoch, args.epochs)
487
+ for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):
488
+ # update weight decay and learning rate according to their schedule
489
+ it = len(data_loader) * epoch + it # global training iteration
490
+ for i, param_group in enumerate(optimizer.param_groups):
491
+ param_group["lr"] = lr_schedule[it]
492
+ if i == 0: # only the first group is regularized
493
+ param_group["weight_decay"] = wd_schedule[it]
494
+
495
+ # move images to gpu
496
+ images = [im.cuda(non_blocking=True) for im in images]
497
+ # teacher and student forward passes + compute dino loss
498
+ with torch.cuda.amp.autocast(fp16_scaler is not None):
499
+ teacher_output = teacher(
500
+ images[:2]
501
+ ) # only the 2 global views pass through the teacher
502
+ student_output = student(images)
503
+ loss = dino_loss(student_output, teacher_output, epoch)
504
+
505
+ if not math.isfinite(loss.item()):
506
+ print("Loss is {}, stopping training".format(loss.item()), force=True)
507
+ sys.exit(1)
508
+
509
+ # student update
510
+ optimizer.zero_grad()
511
+ param_norms = None
512
+ if fp16_scaler is None:
513
+ loss.backward()
514
+ if args.clip_grad:
515
+ param_norms = utils.clip_gradients(student, args.clip_grad)
516
+ utils.cancel_gradients_last_layer(epoch, student, args.freeze_last_layer)
517
+ optimizer.step()
518
+ else:
519
+ fp16_scaler.scale(loss).backward()
520
+ if args.clip_grad:
521
+ fp16_scaler.unscale_(
522
+ optimizer
523
+ ) # unscale the gradients of optimizer's assigned params in-place
524
+ param_norms = utils.clip_gradients(student, args.clip_grad)
525
+ utils.cancel_gradients_last_layer(epoch, student, args.freeze_last_layer)
526
+ fp16_scaler.step(optimizer)
527
+ fp16_scaler.update()
528
+
529
+ # EMA update for the teacher
530
+ with torch.no_grad():
531
+ m = momentum_schedule[it] # momentum parameter
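+ # teacher_params = m * teacher_params + (1 - m) * student_params (exponential moving average)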
532
+ for param_q, param_k in zip(
533
+ student.module.parameters(), teacher_without_ddp.parameters()
534
+ ):
535
+ param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
536
+
537
+ # logging
538
+ torch.cuda.synchronize()
539
+ metric_logger.update(loss=loss.item())
540
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
541
+ metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
542
+ # gather the stats from all processes
543
+ metric_logger.synchronize_between_processes()
544
+ print("Averaged stats:", metric_logger)
545
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
546
+
547
+
548
+ class DINOLoss(nn.Module):
549
+ def __init__(
550
+ self,
551
+ out_dim,
552
+ ncrops,
553
+ warmup_teacher_temp,
554
+ teacher_temp,
555
+ warmup_teacher_temp_epochs,
556
+ nepochs,
557
+ student_temp=0.1,
558
+ center_momentum=0.9,
559
+ ):
560
+ super().__init__()
561
+ self.student_temp = student_temp
562
+ self.center_momentum = center_momentum
563
+ self.ncrops = ncrops
564
+ self.register_buffer("center", torch.zeros(1, out_dim))
565
+ # we apply a warm up for the teacher temperature because
566
+ # a temperature that is too high makes training unstable at the beginning
567
+ self.teacher_temp_schedule = np.concatenate(
568
+ (
569
+ np.linspace(
570
+ warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs
571
+ ),
572
+ np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp,
573
+ )
574
+ )
575
+
576
+ def forward(self, student_output, teacher_output, epoch):
577
+ """
578
+ Cross-entropy between softmax outputs of the teacher and student networks.
579
+ """
580
+ student_out = student_output / self.student_temp
581
+ student_out = student_out.chunk(self.ncrops)
582
+
583
+ # teacher centering and sharpening
584
+ temp = self.teacher_temp_schedule[epoch]
585
+ teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
586
+ teacher_out = teacher_out.detach().chunk(2)
587
+
588
+ total_loss = 0
589
+ n_loss_terms = 0
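+ # cross-entropy between each of the 2 teacher (global) views and every other student view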
590
+ for iq, q in enumerate(teacher_out):
591
+ for v in range(len(student_out)):
592
+ if v == iq:
593
+ # we skip cases where student and teacher operate on the same view
594
+ continue
595
+ loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
596
+ total_loss += loss.mean()
597
+ n_loss_terms += 1
598
+ total_loss /= n_loss_terms
599
+ self.update_center(teacher_output)
600
+ return total_loss
601
+
602
+ @torch.no_grad()
603
+ def update_center(self, teacher_output):
604
+ """
605
+ Update center used for teacher output.
606
+ """
607
+ batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
608
+ dist.all_reduce(batch_center)
609
+ batch_center = batch_center / (len(teacher_output) * dist.get_world_size())
610
+
611
+ # ema update
612
+ self.center = self.center * self.center_momentum + batch_center * (
613
+ 1 - self.center_momentum
614
+ )
615
+
616
+
617
+ class DataAugmentationDINO(object):
618
+ def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
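+ # multi-crop strategy: two global 224x224 crops and local_crops_number local 96x96 crops,
+ # each with flip / color-jitter / grayscale / blur style augmentations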
619
+ flip_and_color_jitter = transforms.Compose(
620
+ [
621
+ transforms.RandomHorizontalFlip(p=0.5),
622
+ transforms.RandomApply(
623
+ [
624
+ transforms.ColorJitter(
625
+ brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
626
+ )
627
+ ],
628
+ p=0.8,
629
+ ),
630
+ transforms.RandomGrayscale(p=0.2),
631
+ ]
632
+ )
633
+ normalize = transforms.Compose(
634
+ [
635
+ transforms.ToTensor(),
636
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
637
+ ]
638
+ )
639
+
640
+ # first global crop
641
+ self.global_transfo1 = transforms.Compose(
642
+ [
643
+ transforms.RandomResizedCrop(
644
+ 224, scale=global_crops_scale, interpolation=Image.BICUBIC
645
+ ),
646
+ flip_and_color_jitter,
647
+ utils.GaussianBlur(1.0),
648
+ normalize,
649
+ ]
650
+ )
651
+ # second global crop
652
+ self.global_transfo2 = transforms.Compose(
653
+ [
654
+ transforms.RandomResizedCrop(
655
+ 224, scale=global_crops_scale, interpolation=Image.BICUBIC
656
+ ),
657
+ flip_and_color_jitter,
658
+ utils.GaussianBlur(0.1),
659
+ utils.Solarization(0.2),
660
+ normalize,
661
+ ]
662
+ )
663
+ # transformation for the local small crops
664
+ self.local_crops_number = local_crops_number
665
+ self.local_transfo = transforms.Compose(
666
+ [
667
+ transforms.RandomResizedCrop(
668
+ 96, scale=local_crops_scale, interpolation=Image.BICUBIC
669
+ ),
670
+ flip_and_color_jitter,
671
+ utils.GaussianBlur(p=0.5),
672
+ normalize,
673
+ ]
674
+ )
675
+
676
+ def __call__(self, image):
677
+ crops = []
678
+ crops.append(self.global_transfo1(image))
679
+ crops.append(self.global_transfo2(image))
680
+ for _ in range(self.local_crops_number):
681
+ crops.append(self.local_transfo(image))
682
+ return crops
683
+
684
+
685
+ if __name__ == "__main__":
686
+ parser = argparse.ArgumentParser("DINO", parents=[get_args_parser()])
687
+ args = parser.parse_args()
688
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
689
+ train_dino(args)
dino/run_with_submitit.py ADDED
@@ -0,0 +1,148 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ A script to run multinode training with submitit.
16
+ Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py
17
+ """
18
+ import argparse
19
+ import os
20
+ import uuid
21
+ from pathlib import Path
22
+
23
+ import main_dino
24
+ import submitit
25
+
26
+
27
+ def parse_args():
28
+ parser = argparse.ArgumentParser(
29
+ "Submitit for DINO", parents=[main_dino.get_args_parser()]
30
+ )
31
+ parser.add_argument(
32
+ "--ngpus", default=8, type=int, help="Number of gpus to request on each node"
33
+ )
34
+ parser.add_argument(
35
+ "--nodes", default=2, type=int, help="Number of nodes to request"
36
+ )
37
+ parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job")
38
+
39
+ parser.add_argument(
40
+ "--partition", default="learnfair", type=str, help="Partition where to submit"
41
+ )
42
+ parser.add_argument(
43
+ "--use_volta32", action="store_true", help="Big models? Use this"
44
+ )
45
+ parser.add_argument(
46
+ "--comment",
47
+ default="",
48
+ type=str,
49
+ help="Comment to pass to scheduler, e.g. priority message",
50
+ )
51
+ return parser.parse_args()
52
+
53
+
54
+ def get_shared_folder() -> Path:
55
+ user = os.getenv("USER")
56
+ if Path("/checkpoint/").is_dir():
57
+ p = Path(f"/checkpoint/{user}/experiments")
58
+ p.mkdir(exist_ok=True)
59
+ return p
60
+ raise RuntimeError("No shared folder available")
61
+
62
+
63
+ def get_init_file():
64
+ # Init file must not exist, but its parent dir must exist.
65
+ os.makedirs(str(get_shared_folder()), exist_ok=True)
66
+ init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
67
+ if init_file.exists():
68
+ os.remove(str(init_file))
69
+ return init_file
70
+
71
+
72
+ class Trainer(object):
73
+ def __init__(self, args):
74
+ self.args = args
75
+
76
+ def __call__(self):
77
+ import main_dino
78
+
79
+ self._setup_gpu_args()
80
+ main_dino.train_dino(self.args)
81
+
82
+ def checkpoint(self):
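+ # called by submitit on preemption or timeout: resubmit the same trainer with a fresh rendezvous file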
83
+ import os
84
+ import submitit
85
+
86
+ self.args.dist_url = get_init_file().as_uri()
87
+ print("Requeuing ", self.args)
88
+ empty_trainer = type(self)(self.args)
89
+ return submitit.helpers.DelayedSubmission(empty_trainer)
90
+
91
+ def _setup_gpu_args(self):
92
+ import submitit
93
+ from pathlib import Path
94
+
95
+ job_env = submitit.JobEnvironment()
96
+ self.args.output_dir = Path(
97
+ str(self.args.output_dir).replace("%j", str(job_env.job_id))
98
+ )
99
+ self.args.gpu = job_env.local_rank
100
+ self.args.rank = job_env.global_rank
101
+ self.args.world_size = job_env.num_tasks
102
+ print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
103
+
104
+
105
+ def main():
106
+ args = parse_args()
107
+ if args.output_dir == "":
108
+ args.output_dir = get_shared_folder() / "%j"
109
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
110
+ executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30)
111
+
112
+ num_gpus_per_node = args.ngpus
113
+ nodes = args.nodes
114
+ timeout_min = args.timeout
115
+
116
+ partition = args.partition
117
+ kwargs = {}
118
+ if args.use_volta32:
119
+ kwargs["slurm_constraint"] = "volta32gb"
120
+ if args.comment:
121
+ kwargs["slurm_comment"] = args.comment
122
+
123
+ executor.update_parameters(
124
+ mem_gb=40 * num_gpus_per_node,
125
+ gpus_per_node=num_gpus_per_node,
126
+ tasks_per_node=num_gpus_per_node, # one task per GPU
127
+ cpus_per_task=10,
128
+ nodes=nodes,
129
+ timeout_min=timeout_min, # max is 60 * 72
130
+ # Below are cluster dependent parameters
131
+ slurm_partition=partition,
132
+ slurm_signal_delay_s=120,
133
+ **kwargs,
134
+ )
135
+
136
+ executor.update_parameters(name="dino")
137
+
138
+ args.dist_url = get_init_file().as_uri()
139
+
140
+ trainer = Trainer(args)
141
+ job = executor.submit(trainer)
142
+
143
+ print(f"Submitted job_id: {job.job_id}")
144
+ print(f"Logs and checkpoints will be saved at: {args.output_dir}")
145
+
146
+
147
+ if __name__ == "__main__":
148
+ main()
dino/utils.py ADDED
@@ -0,0 +1,912 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Misc functions.
16
+
17
+ Mostly copy-paste from torchvision references or other public repos like DETR:
18
+ https://github.com/facebookresearch/detr/blob/master/util/misc.py
19
+ """
20
+ import argparse
+ import os
+ import warnings
21
+ import sys
22
+ import time
23
+ import math
24
+ import random
25
+ import datetime
26
+ import subprocess
27
+ from collections import defaultdict, deque
28
+
29
+ import numpy as np
30
+ import torch
31
+ from torch import nn
32
+ import torch.distributed as dist
33
+ from PIL import ImageFilter, ImageOps
34
+
35
+
36
+ class GaussianBlur(object):
37
+ """
38
+ Apply Gaussian Blur to the PIL image.
39
+ """
40
+
41
+ def __init__(self, p=0.5, radius_min=0.1, radius_max=2.0):
42
+ self.prob = p
43
+ self.radius_min = radius_min
44
+ self.radius_max = radius_max
45
+
46
+ def __call__(self, img):
47
+ do_it = random.random() <= self.prob
48
+ if not do_it:
49
+ return img
50
+
51
+ return img.filter(
52
+ ImageFilter.GaussianBlur(
53
+ radius=random.uniform(self.radius_min, self.radius_max)
54
+ )
55
+ )
56
+
57
+
58
+ class Solarization(object):
59
+ """
60
+ Apply Solarization to the PIL image.
61
+ """
62
+
63
+ def __init__(self, p):
64
+ self.p = p
65
+
66
+ def __call__(self, img):
67
+ if random.random() < self.p:
68
+ return ImageOps.solarize(img)
69
+ else:
70
+ return img
71
+
72
+
73
+ def load_pretrained_weights(
74
+ model, pretrained_weights, checkpoint_key, model_name, patch_size
75
+ ):
76
+ if os.path.isfile(pretrained_weights):
77
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
78
+ if checkpoint_key is not None and checkpoint_key in state_dict:
79
+ print(f"Take key {checkpoint_key} in provided checkpoint dict")
80
+ state_dict = state_dict[checkpoint_key]
81
+ # remove `module.` prefix
82
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
83
+ # remove `backbone.` prefix induced by multicrop wrapper
84
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
85
+ msg = model.load_state_dict(state_dict, strict=False)
86
+ print(
87
+ "Pretrained weights found at {} and loaded with msg: {}".format(
88
+ pretrained_weights, msg
89
+ )
90
+ )
91
+ else:
92
+ print(
93
+ "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
94
+ )
95
+ url = None
96
+ if model_name == "vit_small" and patch_size == 16:
97
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
98
+ elif model_name == "vit_small" and patch_size == 8:
99
+ url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
100
+ elif model_name == "vit_base" and patch_size == 16:
101
+ url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
102
+ elif model_name == "vit_base" and patch_size == 8:
103
+ url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
104
+ elif model_name == "xcit_small_12_p16":
105
+ url = "dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth"
106
+ elif model_name == "xcit_small_12_p8":
107
+ url = "dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth"
108
+ elif model_name == "xcit_medium_24_p16":
109
+ url = (
110
+ "dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth"
111
+ )
112
+ elif model_name == "xcit_medium_24_p8":
113
+ url = "dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth"
114
+ elif model_name == "resnet50":
115
+ url = "dino_resnet50_pretrain/dino_resnet50_pretrain.pth"
116
+ if url is not None:
117
+ print(
118
+ "Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
119
+ )
120
+ state_dict = torch.hub.load_state_dict_from_url(
121
+ url="https://dl.fbaipublicfiles.com/dino/" + url
122
+ )
123
+ model.load_state_dict(state_dict, strict=True)
124
+ else:
125
+ print(
126
+ "There is no reference weights available for this model => We use random weights."
127
+ )
128
+
129
+
130
+ def load_pretrained_linear_weights(linear_classifier, model_name, patch_size):
131
+ url = None
132
+ if model_name == "vit_small" and patch_size == 16:
133
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth"
134
+ elif model_name == "vit_small" and patch_size == 8:
135
+ url = "dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth"
136
+ elif model_name == "vit_base" and patch_size == 16:
137
+ url = "dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth"
138
+ elif model_name == "vit_base" and patch_size == 8:
139
+ url = "dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth"
140
+ elif model_name == "resnet50":
141
+ url = "dino_resnet50_pretrain/dino_resnet50_linearweights.pth"
142
+ if url is not None:
143
+ print("We load the reference pretrained linear weights.")
144
+ state_dict = torch.hub.load_state_dict_from_url(
145
+ url="https://dl.fbaipublicfiles.com/dino/" + url
146
+ )["state_dict"]
147
+ linear_classifier.load_state_dict(state_dict, strict=True)
148
+ else:
149
+ print("We use random linear weights.")
150
+
151
+
152
+ def clip_gradients(model, clip):
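+ # per-parameter gradient clipping: scale down any gradient whose L2 norm exceeds `clip`, and return all gradient norms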
153
+ norms = []
154
+ for name, p in model.named_parameters():
155
+ if p.grad is not None:
156
+ param_norm = p.grad.data.norm(2)
157
+ norms.append(param_norm.item())
158
+ clip_coef = clip / (param_norm + 1e-6)
159
+ if clip_coef < 1:
160
+ p.grad.data.mul_(clip_coef)
161
+ return norms
162
+
163
+
164
+ def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
165
+ if epoch >= freeze_last_layer:
166
+ return
167
+ for n, p in model.named_parameters():
168
+ if "last_layer" in n:
169
+ p.grad = None
170
+
171
+
172
+ def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
173
+ """
174
+ Re-start from checkpoint
175
+ """
176
+ if not os.path.isfile(ckp_path):
177
+ return
178
+ print("Found checkpoint at {}".format(ckp_path))
179
+
180
+ # open checkpoint file
181
+ checkpoint = torch.load(ckp_path, map_location="cpu")
182
+
183
+ # key is what to look for in the checkpoint file
184
+ # value is the object to load
185
+ # example: {'state_dict': model}
186
+ for key, value in kwargs.items():
187
+ if key in checkpoint and value is not None:
188
+ try:
189
+ msg = value.load_state_dict(checkpoint[key], strict=False)
190
+ print(
191
+ "=> loaded '{}' from checkpoint '{}' with msg {}".format(
192
+ key, ckp_path, msg
193
+ )
194
+ )
195
+ except TypeError:
196
+ try:
197
+ msg = value.load_state_dict(checkpoint[key])
198
+ print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
199
+ except ValueError:
200
+ print(
201
+ "=> failed to load '{}' from checkpoint: '{}'".format(
202
+ key, ckp_path
203
+ )
204
+ )
205
+ else:
206
+ print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))
207
+
208
+ # re load variable important for the run
209
+ if run_variables is not None:
210
+ for var_name in run_variables:
211
+ if var_name in checkpoint:
212
+ run_variables[var_name] = checkpoint[var_name]
213
+
214
+
215
+ def cosine_scheduler(
216
+ base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0
217
+ ):
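+ # per-iteration schedule: linear warmup from start_warmup_value to base_value over warmup_epochs, then cosine decay to final_value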
218
+ warmup_schedule = np.array([])
219
+ warmup_iters = warmup_epochs * niter_per_ep
220
+ if warmup_epochs > 0:
221
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
222
+
223
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
224
+ schedule = final_value + 0.5 * (base_value - final_value) * (
225
+ 1 + np.cos(np.pi * iters / len(iters))
226
+ )
227
+
228
+ schedule = np.concatenate((warmup_schedule, schedule))
229
+ assert len(schedule) == epochs * niter_per_ep
230
+ return schedule
231
+
232
+
233
+ def bool_flag(s):
234
+ """
235
+ Parse boolean arguments from the command line.
236
+ """
237
+ FALSY_STRINGS = {"off", "false", "0"}
238
+ TRUTHY_STRINGS = {"on", "true", "1"}
239
+ if s.lower() in FALSY_STRINGS:
240
+ return False
241
+ elif s.lower() in TRUTHY_STRINGS:
242
+ return True
243
+ else:
244
+ raise argparse.ArgumentTypeError("invalid value for a boolean flag")
245
+
246
+
247
+ def fix_random_seeds(seed=31):
248
+ """
249
+ Fix random seeds.
250
+ """
251
+ torch.manual_seed(seed)
252
+ torch.cuda.manual_seed_all(seed)
253
+ np.random.seed(seed)
254
+
255
+
256
+ class SmoothedValue(object):
257
+ """Track a series of values and provide access to smoothed values over a
258
+ window or the global series average.
259
+ """
260
+
261
+ def __init__(self, window_size=20, fmt=None):
262
+ if fmt is None:
263
+ fmt = "{median:.6f} ({global_avg:.6f})"
264
+ self.deque = deque(maxlen=window_size)
265
+ self.total = 0.0
266
+ self.count = 0
267
+ self.fmt = fmt
268
+
269
+ def update(self, value, n=1):
270
+ self.deque.append(value)
271
+ self.count += n
272
+ self.total += value * n
273
+
274
+ def synchronize_between_processes(self):
275
+ """
276
+ Warning: does not synchronize the deque!
277
+ """
278
+ if not is_dist_avail_and_initialized():
279
+ return
280
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
281
+ dist.barrier()
282
+ dist.all_reduce(t)
283
+ t = t.tolist()
284
+ self.count = int(t[0])
285
+ self.total = t[1]
286
+
287
+ @property
288
+ def median(self):
289
+ d = torch.tensor(list(self.deque))
290
+ return d.median().item()
291
+
292
+ @property
293
+ def avg(self):
294
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
295
+ return d.mean().item()
296
+
297
+ @property
298
+ def global_avg(self):
299
+ return self.total / self.count
300
+
301
+ @property
302
+ def max(self):
303
+ return max(self.deque)
304
+
305
+ @property
306
+ def value(self):
307
+ return self.deque[-1]
308
+
309
+ def __str__(self):
310
+ return self.fmt.format(
311
+ median=self.median,
312
+ avg=self.avg,
313
+ global_avg=self.global_avg,
314
+ max=self.max,
315
+ value=self.value,
316
+ )
317
+
318
+
319
+ def reduce_dict(input_dict, average=True):
320
+ """
321
+ Args:
322
+ input_dict (dict): all the values will be reduced
323
+ average (bool): whether to do average or sum
324
+ Reduce the values in the dictionary from all processes so that all processes
325
+ have the averaged results. Returns a dict with the same fields as
326
+ input_dict, after reduction.
327
+ """
328
+ world_size = get_world_size()
329
+ if world_size < 2:
330
+ return input_dict
331
+ with torch.no_grad():
332
+ names = []
333
+ values = []
334
+ # sort the keys so that they are consistent across processes
335
+ for k in sorted(input_dict.keys()):
336
+ names.append(k)
337
+ values.append(input_dict[k])
338
+ values = torch.stack(values, dim=0)
339
+ dist.all_reduce(values)
340
+ if average:
341
+ values /= world_size
342
+ reduced_dict = {k: v for k, v in zip(names, values)}
343
+ return reduced_dict
344
+
345
+
346
+ class MetricLogger(object):
347
+ def __init__(self, delimiter="\t"):
348
+ self.meters = defaultdict(SmoothedValue)
349
+ self.delimiter = delimiter
350
+
351
+ def update(self, **kwargs):
352
+ for k, v in kwargs.items():
353
+ if isinstance(v, torch.Tensor):
354
+ v = v.item()
355
+ assert isinstance(v, (float, int))
356
+ self.meters[k].update(v)
357
+
358
+ def __getattr__(self, attr):
359
+ if attr in self.meters:
360
+ return self.meters[attr]
361
+ if attr in self.__dict__:
362
+ return self.__dict__[attr]
363
+ raise AttributeError(
364
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
365
+ )
366
+
367
+ def __str__(self):
368
+ loss_str = []
369
+ for name, meter in self.meters.items():
370
+ loss_str.append("{}: {}".format(name, str(meter)))
371
+ return self.delimiter.join(loss_str)
372
+
373
+ def synchronize_between_processes(self):
374
+ for meter in self.meters.values():
375
+ meter.synchronize_between_processes()
376
+
377
+ def add_meter(self, name, meter):
378
+ self.meters[name] = meter
379
+
380
+ def log_every(self, iterable, print_freq, header=None):
381
+ i = 0
382
+ if not header:
383
+ header = ""
384
+ start_time = time.time()
385
+ end = time.time()
386
+ iter_time = SmoothedValue(fmt="{avg:.6f}")
387
+ data_time = SmoothedValue(fmt="{avg:.6f}")
388
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
389
+ if torch.cuda.is_available():
390
+ log_msg = self.delimiter.join(
391
+ [
392
+ header,
393
+ "[{0" + space_fmt + "}/{1}]",
394
+ "eta: {eta}",
395
+ "{meters}",
396
+ "time: {time}",
397
+ "data: {data}",
398
+ "max mem: {memory:.0f}",
399
+ ]
400
+ )
401
+ else:
402
+ log_msg = self.delimiter.join(
403
+ [
404
+ header,
405
+ "[{0" + space_fmt + "}/{1}]",
406
+ "eta: {eta}",
407
+ "{meters}",
408
+ "time: {time}",
409
+ "data: {data}",
410
+ ]
411
+ )
412
+ MB = 1024.0 * 1024.0
413
+ for obj in iterable:
414
+ data_time.update(time.time() - end)
415
+ yield obj
416
+ iter_time.update(time.time() - end)
417
+ if i % print_freq == 0 or i == len(iterable) - 1:
418
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
419
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
420
+ if torch.cuda.is_available():
421
+ print(
422
+ log_msg.format(
423
+ i,
424
+ len(iterable),
425
+ eta=eta_string,
426
+ meters=str(self),
427
+ time=str(iter_time),
428
+ data=str(data_time),
429
+ memory=torch.cuda.max_memory_allocated() / MB,
430
+ )
431
+ )
432
+ else:
433
+ print(
434
+ log_msg.format(
435
+ i,
436
+ len(iterable),
437
+ eta=eta_string,
438
+ meters=str(self),
439
+ time=str(iter_time),
440
+ data=str(data_time),
441
+ )
442
+ )
443
+ i += 1
444
+ end = time.time()
445
+ total_time = time.time() - start_time
446
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
447
+ print(
448
+ "{} Total time: {} ({:.6f} s / it)".format(
449
+ header, total_time_str, total_time / len(iterable)
450
+ )
451
+ )
452
+
453
+
454
+ def get_sha():
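+ # best-effort lookup of the current git commit, dirty/clean status and branch (for logging)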
455
+ cwd = os.path.dirname(os.path.abspath(__file__))
456
+
457
+ def _run(command):
458
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
459
+
460
+ sha = "N/A"
461
+ diff = "clean"
462
+ branch = "N/A"
463
+ try:
464
+ sha = _run(["git", "rev-parse", "HEAD"])
465
+ subprocess.check_output(["git", "diff"], cwd=cwd)
466
+ diff = _run(["git", "diff-index", "HEAD"])
467
+ diff = "has uncommited changes" if diff else "clean"
468
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
469
+ except Exception:
470
+ pass
471
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
472
+ return message
473
+
474
+
475
+ def is_dist_avail_and_initialized():
476
+ if not dist.is_available():
477
+ return False
478
+ if not dist.is_initialized():
479
+ return False
480
+ return True
481
+
482
+
483
+ def get_world_size():
484
+ if not is_dist_avail_and_initialized():
485
+ return 1
486
+ return dist.get_world_size()
487
+
488
+
489
+ def get_rank():
490
+ if not is_dist_avail_and_initialized():
491
+ return 0
492
+ return dist.get_rank()
493
+
494
+
495
+ def is_main_process():
496
+ return get_rank() == 0
497
+
498
+
499
+ def save_on_master(*args, **kwargs):
500
+ if is_main_process():
501
+ torch.save(*args, **kwargs)
502
+
503
+
504
+ def setup_for_distributed(is_master):
505
+ """
506
+ This function disables printing when not in master process
507
+ """
508
+ import builtins as __builtin__
509
+
510
+ builtin_print = __builtin__.print
511
+
512
+ def print(*args, **kwargs):
513
+ force = kwargs.pop("force", False)
514
+ if is_master or force:
515
+ builtin_print(*args, **kwargs)
516
+
517
+ __builtin__.print = print
518
+
519
+
520
+ def init_distributed_mode(args):
521
+ # launched with torch.distributed.launch
522
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
523
+ args.rank = int(os.environ["RANK"])
524
+ args.world_size = int(os.environ["WORLD_SIZE"])
525
+ args.gpu = int(os.environ["LOCAL_RANK"])
526
+ # launched with submitit on a slurm cluster
527
+ elif "SLURM_PROCID" in os.environ:
528
+ args.rank = int(os.environ["SLURM_PROCID"])
529
+ args.gpu = args.rank % torch.cuda.device_count()
530
+ # launched naively with `python main_dino.py`
531
+ # we manually add MASTER_ADDR and MASTER_PORT to env variables
532
+ elif torch.cuda.is_available():
533
+ print("Will run the code on one GPU.")
534
+ args.rank, args.gpu, args.world_size = 0, 0, 1
535
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
536
+ os.environ["MASTER_PORT"] = "29500"
537
+ else:
538
+ print("Does not support training without GPU.")
539
+ sys.exit(1)
540
+
541
+ dist.init_process_group(
542
+ backend="nccl",
543
+ init_method=args.dist_url,
544
+ world_size=args.world_size,
545
+ rank=args.rank,
546
+ )
547
+
548
+ torch.cuda.set_device(args.gpu)
549
+ print(
550
+ "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True
551
+ )
552
+ dist.barrier()
553
+ setup_for_distributed(args.rank == 0)
554
+
555
+
556
+ def accuracy(output, target, topk=(1,)):
557
+ """Computes the accuracy over the k top predictions for the specified values of k"""
558
+ maxk = max(topk)
559
+ batch_size = target.size(0)
560
+ _, pred = output.topk(maxk, 1, True, True)
561
+ pred = pred.t()
562
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
563
+ return [correct[:k].reshape(-1).float().sum(0) * 100.0 / batch_size for k in topk]
564
+
565
+
566
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
567
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
568
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
569
+ def norm_cdf(x):
570
+ # Computes standard normal cumulative distribution function
571
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
572
+
573
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
574
+ warnings.warn(
575
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
576
+ "The distribution of values may be incorrect.",
577
+ stacklevel=2,
578
+ )
579
+
580
+ with torch.no_grad():
581
+ # Values are generated by using a truncated uniform distribution and
582
+ # then using the inverse CDF for the normal distribution.
583
+ # Get upper and lower cdf values
584
+ l = norm_cdf((a - mean) / std)
585
+ u = norm_cdf((b - mean) / std)
586
+
587
+ # Uniformly fill tensor with values from [l, u], then translate to
588
+ # [2l-1, 2u-1].
589
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
590
+
591
+ # Use inverse cdf transform for normal distribution to get truncated
592
+ # standard normal
593
+ tensor.erfinv_()
594
+
595
+ # Transform to proper mean, std
596
+ tensor.mul_(std * math.sqrt(2.0))
597
+ tensor.add_(mean)
598
+
599
+ # Clamp to ensure it's in the proper range
600
+ tensor.clamp_(min=a, max=b)
601
+ return tensor
602
+
603
+
604
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
605
+ # type: (Tensor, float, float, float, float) -> Tensor
606
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
607
+
608
+
609
+ class LARS(torch.optim.Optimizer):
610
+ """
611
+ Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
612
+ """
613
+
614
+ def __init__(
615
+ self,
616
+ params,
617
+ lr=0,
618
+ weight_decay=0,
619
+ momentum=0.9,
620
+ eta=0.001,
621
+ weight_decay_filter=None,
622
+ lars_adaptation_filter=None,
623
+ ):
624
+ defaults = dict(
625
+ lr=lr,
626
+ weight_decay=weight_decay,
627
+ momentum=momentum,
628
+ eta=eta,
629
+ weight_decay_filter=weight_decay_filter,
630
+ lars_adaptation_filter=lars_adaptation_filter,
631
+ )
632
+ super().__init__(params, defaults)
633
+
634
+ @torch.no_grad()
635
+ def step(self):
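+ # LARS update: non 1-D params get weight decay plus a trust-ratio scaling eta * ||param|| / ||update||;
+ # 1-D params (biases, norms) skip both; momentum is then applied to all params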
636
+ for g in self.param_groups:
637
+ for p in g["params"]:
638
+ dp = p.grad
639
+
640
+ if dp is None:
641
+ continue
642
+
643
+ if p.ndim != 1:
644
+ dp = dp.add(p, alpha=g["weight_decay"])
645
+
646
+ if p.ndim != 1:
647
+ param_norm = torch.norm(p)
648
+ update_norm = torch.norm(dp)
649
+ one = torch.ones_like(param_norm)
650
+ q = torch.where(
651
+ param_norm > 0.0,
652
+ torch.where(
653
+ update_norm > 0, (g["eta"] * param_norm / update_norm), one
654
+ ),
655
+ one,
656
+ )
657
+ dp = dp.mul(q)
658
+
659
+ param_state = self.state[p]
660
+ if "mu" not in param_state:
661
+ param_state["mu"] = torch.zeros_like(p)
662
+ mu = param_state["mu"]
663
+ mu.mul_(g["momentum"]).add_(dp)
664
+
665
+ p.add_(mu, alpha=-g["lr"])
666
+
667
+
668
+ class MultiCropWrapper(nn.Module):
669
+ """
670
+ Perform forward pass separately on each resolution input.
671
+ The inputs corresponding to a single resolution are grouped together, and a single
672
+ forward pass is run on them. Hence the number of forward passes equals the
673
+ number of distinct resolutions used. We then
674
+ concatenate all the output features and run the head forward on these
675
+ concatenated features.
676
+ """
677
+
678
+ def __init__(self, backbone, head):
679
+ super(MultiCropWrapper, self).__init__()
680
+ # disable layers dedicated to ImageNet labels classification
681
+ backbone.fc, backbone.head = nn.Identity(), nn.Identity()
682
+ self.backbone = backbone
683
+ self.head = head
684
+
685
+ def forward(self, x):
686
+ # convert to list
687
+ if not isinstance(x, list):
688
+ x = [x]
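+ # group crops by resolution: idx_crops holds the end index of each run of same-sized crops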
689
+ idx_crops = torch.cumsum(
690
+ torch.unique_consecutive(
691
+ torch.tensor([inp.shape[-1] for inp in x]),
692
+ return_counts=True,
693
+ )[1],
694
+ 0,
695
+ )
696
+ start_idx, output = 0, torch.empty(0).to(x[0].device)
697
+ for end_idx in idx_crops:
698
+ _out = self.backbone(torch.cat(x[start_idx:end_idx]))
699
+ # The output is a tuple with XCiT model. See:
700
+ # https://github.com/facebookresearch/xcit/blob/master/xcit.py#L404-L405
701
+ if isinstance(_out, tuple):
702
+ _out = _out[0]
703
+ # accumulate outputs
704
+ output = torch.cat((output, _out))
705
+ start_idx = end_idx
706
+ # Run the head forward on the concatenated features.
707
+ return self.head(output)
708
+
709
+
710
+ def get_params_groups(model):
711
+ regularized = []
712
+ not_regularized = []
713
+ for name, param in model.named_parameters():
714
+ if not param.requires_grad:
715
+ continue
716
+ # we do not regularize biases nor Norm parameters
717
+ if name.endswith(".bias") or len(param.shape) == 1:
718
+ not_regularized.append(param)
719
+ else:
720
+ regularized.append(param)
721
+ return [{"params": regularized}, {"params": not_regularized, "weight_decay": 0.0}]
722
+
723
+
724
+ def has_batchnorms(model):
725
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
726
+ for name, module in model.named_modules():
727
+ if isinstance(module, bn_types):
728
+ return True
729
+ return False
730
+
731
+
732
+ class PCA:
733
+ """
734
+ Class to compute and apply PCA.
735
+ """
736
+
737
+ def __init__(self, dim=256, whit=0.5):
738
+ self.dim = dim
739
+ self.whit = whit
740
+ self.mean = None
741
+
742
+ def train_pca(self, cov):
743
+ """
744
+ Takes a covariance matrix (np.ndarray) as input.
745
+ """
746
+ d, v = np.linalg.eigh(cov)
747
+ eps = d.max() * 1e-5
748
+ n_0 = (d < eps).sum()
749
+ if n_0 > 0:
750
+ d[d < eps] = eps
751
+
752
+ # total energy
753
+ totenergy = d.sum()
754
+
755
+ # sort eigenvectors with eigenvalues order
756
+ idx = np.argsort(d)[::-1][: self.dim]
757
+ d = d[idx]
758
+ v = v[:, idx]
759
+
760
+ print("keeping %.2f %% of the energy" % (d.sum() / totenergy * 100.0))
761
+
762
+ # for the whitening
763
+ d = np.diag(1.0 / d**self.whit)
764
+
765
+ # principal components
766
+ self.dvt = np.dot(d, v.T)
767
+
768
+ def apply(self, x):
769
+ # input is from numpy
770
+ if isinstance(x, np.ndarray):
771
+ if self.mean is not None:
772
+ x -= self.mean
773
+ return np.dot(self.dvt, x.T).T
774
+
775
+ # input is from torch and is on GPU
776
+ if x.is_cuda:
777
+ if self.mean is not None:
778
+ x -= torch.cuda.FloatTensor(self.mean)
779
+ return torch.mm(
780
+ torch.cuda.FloatTensor(self.dvt), x.transpose(0, 1)
781
+ ).transpose(0, 1)
782
+
783
+ # input if from torch, on CPU
784
+ if self.mean is not None:
785
+ x -= torch.FloatTensor(self.mean)
786
+ return torch.mm(torch.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)
787
+
788
+
789
+ def compute_ap(ranks, nres):
790
+ """
791
+ Computes average precision for given ranked indexes.
792
+ Arguments
793
+ ---------
794
+ ranks : zero-based ranks of positive images
795
+ nres : number of positive images
796
+ Returns
797
+ -------
798
+ ap : average precision
799
+ """
800
+
801
+ # number of images ranked by the system
802
+ nimgranks = len(ranks)
803
+
804
+ # accumulate trapezoids in PR-plot
805
+ ap = 0
806
+
807
+ recall_step = 1.0 / nres
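+ # trapezoidal approximation of the area under the precision-recall curve, one recall step of 1/nres per positive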
808
+
809
+ for j in np.arange(nimgranks):
810
+ rank = ranks[j]
811
+
812
+ if rank == 0:
813
+ precision_0 = 1.0
814
+ else:
815
+ precision_0 = float(j) / rank
816
+
817
+ precision_1 = float(j + 1) / (rank + 1)
818
+
819
+ ap += (precision_0 + precision_1) * recall_step / 2.0
820
+
821
+ return ap
822
+
823
+
824
+ def compute_map(ranks, gnd, kappas=[]):
825
+ """
826
+ Computes the mAP for a given set of returned results.
827
+ Usage:
828
+ map = compute_map (ranks, gnd)
829
+ computes mean average precision (map) only
830
+ map, aps, pr, prs = compute_map (ranks, gnd, kappas)
831
+ computes mean average precision (map), average precision (aps) for each query
832
+ computes mean precision at kappas (pr), precision at kappas (prs) for each query
833
+ Notes:
834
+ 1) ranks starts from 0, ranks.shape = db_size X #queries
835
+ 2) The junk results (e.g., the query itself) should be declared in the gnd struct array
836
+ 3) If there are no positive images for some query, that query is excluded from the evaluation
837
+ """
838
+
839
+ map = 0.0
840
+ nq = len(gnd) # number of queries
841
+ aps = np.zeros(nq)
842
+ pr = np.zeros(len(kappas))
843
+ prs = np.zeros((nq, len(kappas)))
844
+ nempty = 0
845
+
846
+ for i in np.arange(nq):
847
+ qgnd = np.array(gnd[i]["ok"])
848
+
849
+ # no positive images, skip from the average
850
+ if qgnd.shape[0] == 0:
851
+ aps[i] = float("nan")
852
+ prs[i, :] = float("nan")
853
+ nempty += 1
854
+ continue
855
+
856
+ try:
857
+ qgndj = np.array(gnd[i]["junk"])
858
+ except:
859
+ qgndj = np.empty(0)
860
+
861
+ # sorted positions of positive and junk images (0 based)
862
+ pos = np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgnd)]
863
+ junk = np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgndj)]
864
+
865
+ k = 0
866
+ ij = 0
867
+ if len(junk):
868
+ # decrease positions of positives based on the number of
869
+ # junk images appearing before them
870
+ ip = 0
871
+ while ip < len(pos):
872
+ while ij < len(junk) and pos[ip] > junk[ij]:
873
+ k += 1
874
+ ij += 1
875
+ pos[ip] = pos[ip] - k
876
+ ip += 1
877
+
878
+ # compute ap
879
+ ap = compute_ap(pos, len(qgnd))
880
+ map = map + ap
881
+ aps[i] = ap
882
+
883
+ # compute precision @ k
884
+ pos += 1 # get it to 1-based
885
+ for j in np.arange(len(kappas)):
886
+ kq = min(max(pos), kappas[j])
887
+ prs[i, j] = (pos <= kq).sum() / kq
888
+ pr = pr + prs[i, :]
889
+
890
+ map = map / (nq - nempty)
891
+ pr = pr / (nq - nempty)
892
+
893
+ return map, aps, pr, prs
894
+
895
+
896
+ def multi_scale(samples, model):
897
+ v = None
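+ # extract features at full, 1/sqrt(2) and 1/2 resolution, average them and L2-normalize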
898
+ for s in [1, 1 / 2 ** (1 / 2), 1 / 2]: # we use 3 different scales
899
+ if s == 1:
900
+ inp = samples.clone()
901
+ else:
902
+ inp = nn.functional.interpolate(
903
+ samples, scale_factor=s, mode="bilinear", align_corners=False
904
+ )
905
+ feats = model(inp).clone()
906
+ if v is None:
907
+ v = feats
908
+ else:
909
+ v += feats
910
+ v /= 3
911
+ v /= v.norm()
912
+ return v
dino/video_generation.py ADDED
@@ -0,0 +1,378 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import glob
16
+ import sys
17
+ import argparse
18
+ import cv2
19
+
20
+ from tqdm import tqdm
21
+ import matplotlib.pyplot as plt
22
+ import torch
23
+ import torch.nn as nn
24
+ import torchvision
25
+ from torchvision import transforms as pth_transforms
26
+ import numpy as np
27
+ from PIL import Image
28
+
29
+ import utils
30
+ import vision_transformer as vits
31
+
32
+
33
+ FOURCC = {
34
+ "mp4": cv2.VideoWriter_fourcc(*"MP4V"),
35
+ "avi": cv2.VideoWriter_fourcc(*"XVID"),
36
+ }
37
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
38
+
39
+
40
+ class VideoGenerator:
41
+ def __init__(self, args):
42
+ self.args = args
43
+ # self.model = None
44
+ # Don't need to load model if you only want a video
45
+ if not self.args.video_only:
46
+ self.model = self.__load_model()
47
+
48
+ def run(self):
49
+ if self.args.input_path is None:
50
+ print(f"Provided input path {self.args.input_path} is non valid.")
51
+ sys.exit(1)
52
+ else:
53
+ if self.args.video_only:
54
+ self._generate_video_from_images(
55
+ self.args.input_path, self.args.output_path
56
+ )
57
+ else:
58
+ # If input path exists
59
+ if os.path.exists(self.args.input_path):
60
+ # If input is a video file
61
+ if os.path.isfile(self.args.input_path):
62
+ frames_folder = os.path.join(self.args.output_path, "frames")
63
+ attention_folder = os.path.join(
64
+ self.args.output_path, "attention"
65
+ )
66
+
67
+ os.makedirs(frames_folder, exist_ok=True)
68
+ os.makedirs(attention_folder, exist_ok=True)
69
+
70
+ self._extract_frames_from_video(
71
+ self.args.input_path, frames_folder
72
+ )
73
+
74
+ self._inference(
75
+ frames_folder,
76
+ attention_folder,
77
+ )
78
+
79
+ self._generate_video_from_images(
80
+ attention_folder, self.args.output_path
81
+ )
82
+
83
+ # If input is a folder of already extracted frames
84
+ if os.path.isdir(self.args.input_path):
85
+ attention_folder = os.path.join(
86
+ self.args.output_path, "attention"
87
+ )
88
+
89
+ os.makedirs(attention_folder, exist_ok=True)
90
+
91
+ self._inference(self.args.input_path, attention_folder)
92
+
93
+ self._generate_video_from_images(
94
+ attention_folder, self.args.output_path
95
+ )
96
+
97
+ # If input path doesn't exists
98
+ else:
99
+ print(f"Provided input path {self.args.input_path} doesn't exists.")
100
+ sys.exit(1)
101
+
102
+ def _extract_frames_from_video(self, inp: str, out: str):
103
+ vidcap = cv2.VideoCapture(inp)
104
+ self.args.fps = vidcap.get(cv2.CAP_PROP_FPS)
105
+
106
+ print(f"Video: {inp} ({self.args.fps} fps)")
107
+ print(f"Extracting frames to {out}")
108
+
109
+ success, image = vidcap.read()
110
+ count = 0
111
+ while success:
112
+ cv2.imwrite(
113
+ os.path.join(out, f"frame-{count:04}.jpg"),
114
+ image,
115
+ )
116
+ success, image = vidcap.read()
117
+ count += 1
118
+
119
+ def _generate_video_from_images(self, inp: str, out: str):
120
+ img_array = []
121
+ attention_images_list = sorted(glob.glob(os.path.join(inp, "attn-*.jpg")))
122
+
123
+ # Get size of the first image
124
+ with open(attention_images_list[0], "rb") as f:
125
+ img = Image.open(f)
126
+ img = img.convert("RGB")
127
+ size = (img.width, img.height)
128
+ img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
129
+
130
+ print(f"Generating video {size} to {out}")
131
+
132
+ for filename in tqdm(attention_images_list[1:]):
133
+ with open(filename, "rb") as f:
134
+ img = Image.open(f)
135
+ img = img.convert("RGB")
136
+ img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
137
+
138
+ out = cv2.VideoWriter(
139
+ os.path.join(out, "video." + self.args.video_format),
140
+ FOURCC[self.args.video_format],
141
+ self.args.fps,
142
+ size,
143
+ )
144
+
145
+ for i in range(len(img_array)):
146
+ out.write(img_array[i])
147
+ out.release()
148
+ print("Done")
149
+
150
+ def _inference(self, inp: str, out: str):
151
+ print(f"Generating attention images to {out}")
152
+
153
+ for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))):
154
+ with open(img_path, "rb") as f:
155
+ img = Image.open(f)
156
+ img = img.convert("RGB")
157
+
158
+ if self.args.resize is not None:
159
+ transform = pth_transforms.Compose(
160
+ [
161
+ pth_transforms.ToTensor(),
162
+ pth_transforms.Resize(self.args.resize),
163
+ pth_transforms.Normalize(
164
+ (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
165
+ ),
166
+ ]
167
+ )
168
+ else:
169
+ transform = pth_transforms.Compose(
170
+ [
171
+ pth_transforms.ToTensor(),
172
+ pth_transforms.Normalize(
173
+ (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
174
+ ),
175
+ ]
176
+ )
177
+
178
+ img = transform(img)
179
+
180
+ # make the image divisible by the patch size
181
+ w, h = (
182
+ img.shape[1] - img.shape[1] % self.args.patch_size,
183
+ img.shape[2] - img.shape[2] % self.args.patch_size,
184
+ )
185
+ img = img[:, :w, :h].unsqueeze(0)
186
+
187
+ w_featmap = img.shape[-2] // self.args.patch_size
188
+ h_featmap = img.shape[-1] // self.args.patch_size
189
+
190
+ attentions = self.model.get_last_selfattention(img.to(DEVICE))
191
+
192
+ nh = attentions.shape[1] # number of head
193
+
194
+ # we keep only the output patch attention
195
+ attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
196
+
197
+ # we keep only a certain percentage of the mass
198
+ val, idx = torch.sort(attentions)
199
+ val /= torch.sum(val, dim=1, keepdim=True)
200
+ cumval = torch.cumsum(val, dim=1)
201
+ th_attn = cumval > (1 - self.args.threshold)
202
+ idx2 = torch.argsort(idx)
203
+ for head in range(nh):
204
+ th_attn[head] = th_attn[head][idx2[head]]
205
+ th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
206
+ # interpolate
207
+ th_attn = (
208
+ nn.functional.interpolate(
209
+ th_attn.unsqueeze(0),
210
+ scale_factor=self.args.patch_size,
211
+ mode="nearest",
212
+ )[0]
213
+ .cpu()
214
+ .numpy()
215
+ )
216
+
217
+ attentions = attentions.reshape(nh, w_featmap, h_featmap)
218
+ attentions = (
219
+ nn.functional.interpolate(
220
+ attentions.unsqueeze(0),
221
+ scale_factor=self.args.patch_size,
222
+ mode="nearest",
223
+ )[0]
224
+ .cpu()
225
+ .numpy()
226
+ )
227
+
228
+ # save attentions heatmaps
229
+ fname = os.path.join(out, "attn-" + os.path.basename(img_path))
230
+ plt.imsave(
231
+ fname=fname,
232
+ arr=sum(
233
+ attentions[i] * 1 / attentions.shape[0]
234
+ for i in range(attentions.shape[0])
235
+ ),
236
+ cmap="inferno",
237
+ format="jpg",
238
+ )
239
+
240
+ def __load_model(self):
241
+ # build model
242
+ model = vits.__dict__[self.args.arch](
243
+ patch_size=self.args.patch_size, num_classes=0
244
+ )
245
+ for p in model.parameters():
246
+ p.requires_grad = False
247
+ model.eval()
248
+ model.to(DEVICE)
249
+
250
+ if os.path.isfile(self.args.pretrained_weights):
251
+ state_dict = torch.load(self.args.pretrained_weights, map_location="cpu")
252
+ if (
253
+ self.args.checkpoint_key is not None
254
+ and self.args.checkpoint_key in state_dict
255
+ ):
256
+ print(
257
+ f"Take key {self.args.checkpoint_key} in provided checkpoint dict"
258
+ )
259
+ state_dict = state_dict[self.args.checkpoint_key]
260
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
261
+ # remove `backbone.` prefix induced by multicrop wrapper
262
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
263
+ msg = model.load_state_dict(state_dict, strict=False)
264
+ print(
265
+ "Pretrained weights found at {} and loaded with msg: {}".format(
266
+ self.args.pretrained_weights, msg
267
+ )
268
+ )
269
+ else:
270
+ print(
271
+ "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
272
+ )
273
+ url = None
274
+ if self.args.arch == "vit_small" and self.args.patch_size == 16:
275
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
276
+ elif self.args.arch == "vit_small" and self.args.patch_size == 8:
277
+ url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper
278
+ elif self.args.arch == "vit_base" and self.args.patch_size == 16:
279
+ url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
280
+ elif self.args.arch == "vit_base" and self.args.patch_size == 8:
281
+ url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
282
+ if url is not None:
283
+ print(
284
+ "Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
285
+ )
286
+ state_dict = torch.hub.load_state_dict_from_url(
287
+ url="https://dl.fbaipublicfiles.com/dino/" + url
288
+ )
289
+ model.load_state_dict(state_dict, strict=True)
290
+ else:
291
+ print(
292
+ "There is no reference weights available for this model => We use random weights."
293
+ )
294
+ return model
295
+
296
+
297
+ def parse_args():
298
+ parser = argparse.ArgumentParser("Generation self-attention video")
299
+ parser.add_argument(
300
+ "--arch",
301
+ default="vit_small",
302
+ type=str,
303
+ choices=["vit_tiny", "vit_small", "vit_base"],
304
+ help="Architecture (support only ViT atm).",
305
+ )
306
+ parser.add_argument(
307
+ "--patch_size", default=8, type=int, help="Patch resolution of the self.model."
308
+ )
309
+ parser.add_argument(
310
+ "--pretrained_weights",
311
+ default="",
312
+ type=str,
313
+ help="Path to pretrained weights to load.",
314
+ )
315
+ parser.add_argument(
316
+ "--checkpoint_key",
317
+ default="teacher",
318
+ type=str,
319
+ help='Key to use in the checkpoint (example: "teacher")',
320
+ )
321
+ parser.add_argument(
322
+ "--input_path",
323
+ required=True,
324
+ type=str,
325
+ help="""Path to a video file if you want to extract frames
326
+ or to a folder of images already extracted by yourself.
327
+ or to a folder of attention images.""",
328
+ )
329
+ parser.add_argument(
330
+ "--output_path",
331
+ default="./",
332
+ type=str,
333
+ help="""Path to store a folder of frames and / or a folder of attention images.
334
+ and / or a final video. Default to current directory.""",
335
+ )
336
+ parser.add_argument(
337
+ "--threshold",
338
+ type=float,
339
+ default=0.6,
340
+ help="""We visualize masks
341
+ obtained by thresholding the self-attention maps to keep xx percent of the mass.""",
342
+ )
343
+ parser.add_argument(
344
+ "--resize",
345
+ default=None,
346
+ type=int,
347
+ nargs="+",
348
+ help="""Apply a resize transformation to input image(s). Use if OOM error.
349
+ Usage (single or W H): --resize 512, --resize 720 1280""",
350
+ )
351
+ parser.add_argument(
352
+ "--video_only",
353
+ action="store_true",
354
+ help="""Use this flag if you only want to generate a video and not all attention images.
355
+ If used, --input_path must be set to the folder of attention images. Ex: ./attention/""",
356
+ )
357
+ parser.add_argument(
358
+ "--fps",
359
+ default=30.0,
360
+ type=float,
361
+ help="FPS of input / output video. Automatically set if you extract frames from a video.",
362
+ )
363
+ parser.add_argument(
364
+ "--video_format",
365
+ default="mp4",
366
+ type=str,
367
+ choices=["mp4", "avi"],
368
+ help="Format of generated video (mp4 or avi).",
369
+ )
370
+
371
+ return parser.parse_args()
372
+
373
+
374
+ if __name__ == "__main__":
375
+ args = parse_args()
376
+
377
+ vg = VideoGenerator(args)
378
+ vg.run()
dino/vision_transformer.py ADDED
@@ -0,0 +1,408 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Mostly copy-paste from timm library.
16
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17
+ """
18
+ import math
19
+ from functools import partial
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+
24
+ from utils import trunc_normal_
25
+
26
+
27
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
28
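+ # Stochastic depth: with probability drop_prob, zero the residual branch for a sample;
+ # survivors are scaled by 1 / keep_prob so the expected output is unchanged.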
+ if drop_prob == 0.0 or not training:
29
+ return x
30
+ keep_prob = 1 - drop_prob
31
+ shape = (x.shape[0],) + (1,) * (
32
+ x.ndim - 1
33
+ ) # work with diff dim tensors, not just 2D ConvNets
34
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
35
+ random_tensor.floor_() # binarize
36
+ output = x.div(keep_prob) * random_tensor
37
+ return output
38
+
39
+
40
+ class DropPath(nn.Module):
41
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
42
+
43
+ def __init__(self, drop_prob=None):
44
+ super(DropPath, self).__init__()
45
+ self.drop_prob = drop_prob
46
+
47
+ def forward(self, x):
48
+ return drop_path(x, self.drop_prob, self.training)
49
+
50
+
51
+ class Mlp(nn.Module):
52
+ def __init__(
53
+ self,
54
+ in_features,
55
+ hidden_features=None,
56
+ out_features=None,
57
+ act_layer=nn.GELU,
58
+ drop=0.0,
59
+ ):
60
+ super().__init__()
61
+ out_features = out_features or in_features
62
+ hidden_features = hidden_features or in_features
63
+ self.fc1 = nn.Linear(in_features, hidden_features)
64
+ self.act = act_layer()
65
+ self.fc2 = nn.Linear(hidden_features, out_features)
66
+ self.drop = nn.Dropout(drop)
67
+
68
+ def forward(self, x):
69
+ x = self.fc1(x)
70
+ x = self.act(x)
71
+ x = self.drop(x)
72
+ x = self.fc2(x)
73
+ x = self.drop(x)
74
+ return x
75
+
76
+
77
+ class Attention(nn.Module):
78
+ def __init__(
79
+ self,
80
+ dim,
81
+ num_heads=8,
82
+ qkv_bias=False,
83
+ qk_scale=None,
84
+ attn_drop=0.0,
85
+ proj_drop=0.0,
86
+ ):
87
+ super().__init__()
88
+ self.num_heads = num_heads
89
+ head_dim = dim // num_heads
90
+ self.scale = qk_scale or head_dim**-0.5
91
+
92
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
93
+ self.attn_drop = nn.Dropout(attn_drop)
94
+ self.proj = nn.Linear(dim, dim)
95
+ self.proj_drop = nn.Dropout(proj_drop)
96
+
97
+ def forward(self, x):
98
+ B, N, C = x.shape
99
+ qkv = (
100
+ self.qkv(x)
101
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
102
+ .permute(2, 0, 3, 1, 4)
103
+ )
104
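+ # qkv has shape (3, B, num_heads, N, head_dim): one slice each for queries, keys and values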
+ q, k, v = qkv[0], qkv[1], qkv[2]
105
+
106
+ attn = (q @ k.transpose(-2, -1)) * self.scale
107
+ attn = attn.softmax(dim=-1)
108
+ attn = self.attn_drop(attn)
109
+
110
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
111
+ x = self.proj(x)
112
+ x = self.proj_drop(x)
113
+ return x, attn
114
+
115
+
116
+ class Block(nn.Module):
117
+ def __init__(
118
+ self,
119
+ dim,
120
+ num_heads,
121
+ mlp_ratio=4.0,
122
+ qkv_bias=False,
123
+ qk_scale=None,
124
+ drop=0.0,
125
+ attn_drop=0.0,
126
+ drop_path=0.0,
127
+ act_layer=nn.GELU,
128
+ norm_layer=nn.LayerNorm,
129
+ ):
130
+ super().__init__()
131
+ self.norm1 = norm_layer(dim)
132
+ self.attn = Attention(
133
+ dim,
134
+ num_heads=num_heads,
135
+ qkv_bias=qkv_bias,
136
+ qk_scale=qk_scale,
137
+ attn_drop=attn_drop,
138
+ proj_drop=drop,
139
+ )
140
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
141
+ self.norm2 = norm_layer(dim)
142
+ mlp_hidden_dim = int(dim * mlp_ratio)
143
+ self.mlp = Mlp(
144
+ in_features=dim,
145
+ hidden_features=mlp_hidden_dim,
146
+ act_layer=act_layer,
147
+ drop=drop,
148
+ )
149
+
150
+ def forward(self, x, return_attention=False):
151
+ y, attn = self.attn(self.norm1(x))
152
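+ # when only the attention map is requested (e.g. for visualization), return it before the residual updates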
+ if return_attention:
153
+ return attn
154
+ x = x + self.drop_path(y)
155
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
156
+ return x
157
+
158
+
159
+ class PatchEmbed(nn.Module):
160
+ """Image to Patch Embedding"""
161
+
162
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
163
+ super().__init__()
164
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
165
+ self.img_size = img_size
166
+ self.patch_size = patch_size
167
+ self.num_patches = num_patches
168
+
169
+ self.proj = nn.Conv2d(
170
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
171
+ )
172
+
173
+ def forward(self, x):
174
+ B, C, H, W = x.shape
175
+ x = self.proj(x).flatten(2).transpose(1, 2)
176
+ return x
177
+
178
+
179
+ class VisionTransformer(nn.Module):
180
+ """Vision Transformer"""
181
+
182
+ def __init__(
183
+ self,
184
+ img_size=[224],
185
+ patch_size=16,
186
+ in_chans=3,
187
+ num_classes=0,
188
+ embed_dim=768,
189
+ depth=12,
190
+ num_heads=12,
191
+ mlp_ratio=4.0,
192
+ qkv_bias=False,
193
+ qk_scale=None,
194
+ drop_rate=0.0,
195
+ attn_drop_rate=0.0,
196
+ drop_path_rate=0.0,
197
+ norm_layer=nn.LayerNorm,
198
+ **kwargs
199
+ ):
200
+ super().__init__()
201
+ self.num_features = self.embed_dim = embed_dim
202
+
203
+ self.patch_embed = PatchEmbed(
204
+ img_size=img_size[0],
205
+ patch_size=patch_size,
206
+ in_chans=in_chans,
207
+ embed_dim=embed_dim,
208
+ )
209
+ num_patches = self.patch_embed.num_patches
210
+
211
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
212
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
213
+ self.pos_drop = nn.Dropout(p=drop_rate)
214
+
215
+ dpr = [
216
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
217
+ ] # stochastic depth decay rule
218
+ self.blocks = nn.ModuleList(
219
+ [
220
+ Block(
221
+ dim=embed_dim,
222
+ num_heads=num_heads,
223
+ mlp_ratio=mlp_ratio,
224
+ qkv_bias=qkv_bias,
225
+ qk_scale=qk_scale,
226
+ drop=drop_rate,
227
+ attn_drop=attn_drop_rate,
228
+ drop_path=dpr[i],
229
+ norm_layer=norm_layer,
230
+ )
231
+ for i in range(depth)
232
+ ]
233
+ )
234
+ self.norm = norm_layer(embed_dim)
235
+
236
+ # Classifier head
237
+ self.head = (
238
+ nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
239
+ )
240
+
241
+ trunc_normal_(self.pos_embed, std=0.02)
242
+ trunc_normal_(self.cls_token, std=0.02)
243
+ self.apply(self._init_weights)
244
+
245
+ def _init_weights(self, m):
246
+ if isinstance(m, nn.Linear):
247
+ trunc_normal_(m.weight, std=0.02)
248
+ if isinstance(m, nn.Linear) and m.bias is not None:
249
+ nn.init.constant_(m.bias, 0)
250
+ elif isinstance(m, nn.LayerNorm):
251
+ nn.init.constant_(m.bias, 0)
252
+ nn.init.constant_(m.weight, 1.0)
253
+
254
+ def interpolate_pos_encoding(self, x, w, h):
255
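+ # Bicubically resize the pretrained patch position embeddings so inputs whose resolution differs
+ # from the pretraining size still get a positional grid that matches their number of patches.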
+ npatch = x.shape[1] - 1
256
+ N = self.pos_embed.shape[1] - 1
257
+ if npatch == N and w == h:
258
+ return self.pos_embed
259
+ class_pos_embed = self.pos_embed[:, 0]
260
+ patch_pos_embed = self.pos_embed[:, 1:]
261
+ dim = x.shape[-1]
262
+ w0 = w // self.patch_embed.patch_size
263
+ h0 = h // self.patch_embed.patch_size
264
+ # we add a small number to avoid floating point error in the interpolation
265
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
266
+ w0, h0 = w0 + 0.1, h0 + 0.1
267
+ patch_pos_embed = nn.functional.interpolate(
268
+ patch_pos_embed.reshape(
269
+ 1, int(math.sqrt(N)), int(math.sqrt(N)), dim
270
+ ).permute(0, 3, 1, 2),
271
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
272
+ mode="bicubic",
273
+ )
274
+ assert (
275
+ int(w0) == patch_pos_embed.shape[-2]
276
+ and int(h0) == patch_pos_embed.shape[-1]
277
+ )
278
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
279
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
280
+
281
+ def prepare_tokens(self, x):
282
+ B, nc, w, h = x.shape
283
+ x = self.patch_embed(x) # patch linear embedding
284
+
285
+ # add the [CLS] token to the embed patch tokens
286
+ cls_tokens = self.cls_token.expand(B, -1, -1)
287
+ x = torch.cat((cls_tokens, x), dim=1)
288
+
289
+ # add positional encoding to each token
290
+ x = x + self.interpolate_pos_encoding(x, w, h)
291
+
292
+ return self.pos_drop(x)
293
+
294
+ def forward(self, x):
295
+ x = self.prepare_tokens(x)
296
+ for blk in self.blocks:
297
+ x = blk(x)
298
+ x = self.norm(x)
299
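+ # the [CLS] token embedding is used as the image-level representation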
+ return x[:, 0]
300
+
301
+ def get_last_selfattention(self, x):
302
+ x = self.prepare_tokens(x)
303
+ for i, blk in enumerate(self.blocks):
304
+ if i < len(self.blocks) - 1:
305
+ x = blk(x)
306
+ else:
307
+ # return attention of the last block
308
+ return blk(x, return_attention=True)
309
+
310
+ def get_intermediate_layers(self, x, n=1):
311
+ x = self.prepare_tokens(x)
312
+ # we return the output tokens from the `n` last blocks
313
+ output = []
314
+ for i, blk in enumerate(self.blocks):
315
+ x = blk(x)
316
+ if len(self.blocks) - i <= n:
317
+ output.append(self.norm(x))
318
+ return output
319
+
320
+
321
+ def vit_tiny(patch_size=16, **kwargs):
322
+ model = VisionTransformer(
323
+ patch_size=patch_size,
324
+ embed_dim=192,
325
+ depth=12,
326
+ num_heads=3,
327
+ mlp_ratio=4,
328
+ qkv_bias=True,
329
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
330
+ **kwargs
331
+ )
332
+ return model
333
+
334
+
335
+ def vit_small(patch_size=16, **kwargs):
336
+ model = VisionTransformer(
337
+ patch_size=patch_size,
338
+ embed_dim=384,
339
+ depth=12,
340
+ num_heads=6,
341
+ mlp_ratio=4,
342
+ qkv_bias=True,
343
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
344
+ **kwargs
345
+ )
346
+ return model
347
+
348
+
349
+ def vit_base(patch_size=16, **kwargs):
350
+ model = VisionTransformer(
351
+ patch_size=patch_size,
352
+ embed_dim=768,
353
+ depth=12,
354
+ num_heads=12,
355
+ mlp_ratio=4,
356
+ qkv_bias=True,
357
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
358
+ **kwargs
359
+ )
360
+ return model
361
+
362
+
363
+ class DINOHead(nn.Module):
364
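+ # Projection head used during DINO training: an MLP down to a bottleneck, L2 normalization,
+ # then a weight-normalized linear layer that produces the out_dim prototype scores.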
+ def __init__(
365
+ self,
366
+ in_dim,
367
+ out_dim,
368
+ use_bn=False,
369
+ norm_last_layer=True,
370
+ nlayers=3,
371
+ hidden_dim=2048,
372
+ bottleneck_dim=256,
373
+ ):
374
+ super().__init__()
375
+ nlayers = max(nlayers, 1)
376
+ if nlayers == 1:
377
+ self.mlp = nn.Linear(in_dim, bottleneck_dim)
378
+ else:
379
+ layers = [nn.Linear(in_dim, hidden_dim)]
380
+ if use_bn:
381
+ layers.append(nn.BatchNorm1d(hidden_dim))
382
+ layers.append(nn.GELU())
383
+ for _ in range(nlayers - 2):
384
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
385
+ if use_bn:
386
+ layers.append(nn.BatchNorm1d(hidden_dim))
387
+ layers.append(nn.GELU())
388
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim))
389
+ self.mlp = nn.Sequential(*layers)
390
+ self.apply(self._init_weights)
391
+ self.last_layer = nn.utils.weight_norm(
392
+ nn.Linear(bottleneck_dim, out_dim, bias=False)
393
+ )
394
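+ # with norm_last_layer=True the weight-norm magnitude g stays fixed at 1, which helps stabilize training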
+ self.last_layer.weight_g.data.fill_(1)
395
+ if norm_last_layer:
396
+ self.last_layer.weight_g.requires_grad = False
397
+
398
+ def _init_weights(self, m):
399
+ if isinstance(m, nn.Linear):
400
+ trunc_normal_(m.weight, std=0.02)
401
+ if isinstance(m, nn.Linear) and m.bias is not None:
402
+ nn.init.constant_(m.bias, 0)
403
+
404
+ def forward(self, x):
405
+ x = self.mlp(x)
406
+ x = nn.functional.normalize(x, dim=-1, p=2)
407
+ x = self.last_layer(x)
408
+ return x
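+
+
+ # Minimal usage sketch (assumes pretrained weights are loaded separately, e.g. with torch.hub):
+ #   model = vit_small(patch_size=8)
+ #   x = torch.randn(1, 3, 224, 224)
+ #   cls_embedding = model(x)                # (1, 384) image-level feature
+ #   attn = model.get_last_selfattention(x)  # (1, num_heads, num_tokens, num_tokens)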
dino/visualize_attention.py ADDED
@@ -0,0 +1,287 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ import argparse
17
+ import cv2
18
+ import random
19
+ import colorsys
20
+ import requests
21
+ from io import BytesIO
22
+
23
+ import skimage.io
24
+ from skimage.measure import find_contours
25
+ import matplotlib.pyplot as plt
26
+ from matplotlib.patches import Polygon
27
+ import torch
28
+ import torch.nn as nn
29
+ import torchvision
30
+ from torchvision import transforms as pth_transforms
31
+ import numpy as np
32
+ from PIL import Image
33
+
34
+ import utils
35
+ import vision_transformer as vits
36
+
37
+
38
+ def apply_mask(image, mask, color, alpha=0.5):
39
+ for c in range(3):
40
+ image[:, :, c] = (
41
+ image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255
42
+ )
43
+ return image
44
+
45
+
46
+ def random_colors(N, bright=True):
47
+ """
48
+ Generate random colors.
49
+ """
50
+ brightness = 1.0 if bright else 0.7
51
+ hsv = [(i / N, 1, brightness) for i in range(N)]
52
+ colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
53
+ random.shuffle(colors)
54
+ return colors
55
+
56
+
57
+ def display_instances(
58
+ image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5
59
+ ):
60
+ fig = plt.figure(figsize=figsize, frameon=False)
61
+ ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
62
+ ax.set_axis_off()
63
+ fig.add_axes(ax)
64
+ ax = plt.gca()
65
+
66
+ N = 1
67
+ mask = mask[None, :, :]
68
+ # Generate random colors
69
+ colors = random_colors(N)
70
+
71
+ # Show area outside image boundaries.
72
+ height, width = image.shape[:2]
73
+ margin = 0
74
+ ax.set_ylim(height + margin, -margin)
75
+ ax.set_xlim(-margin, width + margin)
76
+ ax.axis("off")
77
+ masked_image = image.astype(np.uint32).copy()
78
+ for i in range(N):
79
+ color = colors[i]
80
+ _mask = mask[i]
81
+ if blur:
82
+ _mask = cv2.blur(_mask, (10, 10))
83
+ # Mask
84
+ masked_image = apply_mask(masked_image, _mask, color, alpha)
85
+ # Mask Polygon
86
+ # Pad to ensure proper polygons for masks that touch image edges.
87
+ if contour:
88
+ padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2))
89
+ padded_mask[1:-1, 1:-1] = _mask
90
+ contours = find_contours(padded_mask, 0.5)
91
+ for verts in contours:
92
+ # Subtract the padding and flip (y, x) to (x, y)
93
+ verts = np.fliplr(verts) - 1
94
+ p = Polygon(verts, facecolor="none", edgecolor=color)
95
+ ax.add_patch(p)
96
+ ax.imshow(masked_image.astype(np.uint8), aspect="auto")
97
+ fig.savefig(fname)
98
+ print(f"{fname} saved.")
99
+ return
100
+
101
+
102
+ if __name__ == "__main__":
103
+ parser = argparse.ArgumentParser("Visualize Self-Attention maps")
104
+ parser.add_argument(
105
+ "--arch",
106
+ default="vit_small",
107
+ type=str,
108
+ choices=["vit_tiny", "vit_small", "vit_base"],
109
+ help="Architecture (support only ViT atm).",
110
+ )
111
+ parser.add_argument(
112
+ "--patch_size", default=8, type=int, help="Patch resolution of the model."
113
+ )
114
+ parser.add_argument(
115
+ "--pretrained_weights",
116
+ default="",
117
+ type=str,
118
+ help="Path to pretrained weights to load.",
119
+ )
120
+ parser.add_argument(
121
+ "--checkpoint_key",
122
+ default="teacher",
123
+ type=str,
124
+ help='Key to use in the checkpoint (example: "teacher")',
125
+ )
126
+ parser.add_argument(
127
+ "--image_path", default=None, type=str, help="Path of the image to load."
128
+ )
129
+ parser.add_argument(
130
+ "--image_size", default=(480, 480), type=int, nargs="+", help="Resize image."
131
+ )
132
+ parser.add_argument(
133
+ "--output_dir", default=".", help="Path where to save visualizations."
134
+ )
135
+ parser.add_argument(
136
+ "--threshold",
137
+ type=float,
138
+ default=None,
139
+ help="""We visualize masks
140
+ obtained by thresholding the self-attention maps to keep xx% of the mass.""",
141
+ )
142
+ args = parser.parse_args()
143
+
144
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
145
+ # build model
146
+ model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
147
+ for p in model.parameters():
148
+ p.requires_grad = False
149
+ model.eval()
150
+ model.to(device)
151
+ if os.path.isfile(args.pretrained_weights):
152
+ state_dict = torch.load(args.pretrained_weights, map_location="cpu")
153
+ if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
154
+ print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
155
+ state_dict = state_dict[args.checkpoint_key]
156
+ # remove `module.` prefix
157
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
158
+ # remove `backbone.` prefix induced by multicrop wrapper
159
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
160
+ msg = model.load_state_dict(state_dict, strict=False)
161
+ print(
162
+ "Pretrained weights found at {} and loaded with msg: {}".format(
163
+ args.pretrained_weights, msg
164
+ )
165
+ )
166
+ else:
167
+ print(
168
+ "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
169
+ )
170
+ url = None
171
+ if args.arch == "vit_small" and args.patch_size == 16:
172
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
173
+ elif args.arch == "vit_small" and args.patch_size == 8:
174
+ url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper
175
+ elif args.arch == "vit_base" and args.patch_size == 16:
176
+ url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
177
+ elif args.arch == "vit_base" and args.patch_size == 8:
178
+ url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
179
+ if url is not None:
180
+ print(
181
+ "Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
182
+ )
183
+ state_dict = torch.hub.load_state_dict_from_url(
184
+ url="https://dl.fbaipublicfiles.com/dino/" + url
185
+ )
186
+ model.load_state_dict(state_dict, strict=True)
187
+ else:
188
+ print(
189
+ "There is no reference weights available for this model => We use random weights."
190
+ )
191
+
192
+ # open image
193
+ if args.image_path is None:
194
+ # user has not specified any image - we use our own image
195
+ print(
196
+ "Please use the `--image_path` argument to indicate the path of the image you wish to visualize."
197
+ )
198
+ print(
199
+ "Since no image path have been provided, we take the first image in our paper."
200
+ )
201
+ response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png")
202
+ img = Image.open(BytesIO(response.content))
203
+ img = img.convert("RGB")
204
+ elif os.path.isfile(args.image_path):
205
+ with open(args.image_path, "rb") as f:
206
+ img = Image.open(f)
207
+ img = img.convert("RGB")
208
+ else:
209
+ print(f"Provided image path {args.image_path} is non valid.")
210
+ sys.exit(1)
211
+ transform = pth_transforms.Compose(
212
+ [
213
+ pth_transforms.Resize(args.image_size),
214
+ pth_transforms.ToTensor(),
215
+ pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
216
+ ]
217
+ )
218
+ img = transform(img)
219
+
220
+ # make the image divisible by the patch size
221
+ w, h = (
222
+ img.shape[1] - img.shape[1] % args.patch_size,
223
+ img.shape[2] - img.shape[2] % args.patch_size,
224
+ )
225
+ img = img[:, :w, :h].unsqueeze(0)
226
+
227
+ w_featmap = img.shape[-2] // args.patch_size
228
+ h_featmap = img.shape[-1] // args.patch_size
229
+
230
+ attentions = model.get_last_selfattention(img.to(device))
231
+
232
+ nh = attentions.shape[1] # number of head
233
+
234
+ # we keep only the output patch attention
235
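+ # i.e. for each head, the [CLS] query's attention over every patch token (the CLS-to-CLS entry is dropped)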
+ attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
236
+
237
+ if args.threshold is not None:
238
+ # we keep only a certain percentage of the mass
239
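+ # sort each head's attention, normalize it to a distribution, and mark the patches whose cumulative
+ # attention exceeds (1 - threshold); argsort(idx) below maps the flags back to the original patch order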
+ val, idx = torch.sort(attentions)
240
+ val /= torch.sum(val, dim=1, keepdim=True)
241
+ cumval = torch.cumsum(val, dim=1)
242
+ th_attn = cumval > (1 - args.threshold)
243
+ idx2 = torch.argsort(idx)
244
+ for head in range(nh):
245
+ th_attn[head] = th_attn[head][idx2[head]]
246
+ th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
247
+ # interpolate
248
+ th_attn = (
249
+ nn.functional.interpolate(
250
+ th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest"
251
+ )[0]
252
+ .cpu()
253
+ .numpy()
254
+ )
255
+
256
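+ # reshape each head's attention onto the patch grid and upsample it to input resolution
+ # (nearest neighbour, scale factor = patch size)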
+ attentions = attentions.reshape(nh, w_featmap, h_featmap)
257
+ attentions = (
258
+ nn.functional.interpolate(
259
+ attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest"
260
+ )[0]
261
+ .cpu()
262
+ .numpy()
263
+ )
264
+
265
+ # save attentions heatmaps
266
+ os.makedirs(args.output_dir, exist_ok=True)
267
+ torchvision.utils.save_image(
268
+ torchvision.utils.make_grid(img, normalize=True, scale_each=True),
269
+ os.path.join(args.output_dir, "img.png"),
270
+ )
271
+ for j in range(nh):
272
+ fname = os.path.join(args.output_dir, "attn-head" + str(j) + ".png")
273
+ plt.imsave(fname=fname, arr=attentions[j], format="png")
274
+ print(f"{fname} saved.")
275
+
276
+ if args.threshold is not None:
277
+ image = skimage.io.imread(os.path.join(args.output_dir, "img.png"))
278
+ for j in range(nh):
279
+ display_instances(
280
+ image,
281
+ th_attn[j],
282
+ fname=os.path.join(
283
+ args.output_dir,
284
+ "mask_th" + str(args.threshold) + "_head" + str(j) + ".png",
285
+ ),
286
+ blur=False,
287
+ )