Spaces:
Running
Running
hasibzunair
commited on
Commit
•
2895c00
1
Parent(s):
c16c4c8
add dino
Browse files- .DS_Store +0 -0
- dino/README.md +411 -0
- dino/__init__.py +4 -0
- dino/eval_copy_detection.py +381 -0
- dino/eval_image_retrieval.py +274 -0
- dino/eval_knn.py +325 -0
- dino/eval_linear.py +421 -0
- dino/eval_video_segmentation.py +372 -0
- dino/hubconf.py +159 -0
- dino/main_dino.py +689 -0
- dino/run_with_submitit.py +148 -0
- dino/utils.py +912 -0
- dino/video_generation.py +378 -0
- dino/vision_transformer.py +408 -0
- dino/visualize_attention.py +287 -0
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
|
|
dino/README.md
ADDED
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
:new: *Please check out our more recent [DINOv2](https://github.com/facebookresearch/dinov2) effort in the same line of work.*
|
2 |
+
|
3 |
+
# Self-Supervised Vision Transformers with DINO
|
4 |
+
|
5 |
+
PyTorch implementation and pretrained models for DINO. For details, see **Emerging Properties in Self-Supervised Vision Transformers**.
|
6 |
+
[[`blogpost`](https://ai.facebook.com/blog/dino-paws-computer-vision-with-self-supervised-transformers-and-10x-more-efficient-training)] [[`arXiv`](https://arxiv.org/abs/2104.14294)] [[`Yannic Kilcher's video`](https://www.youtube.com/watch?v=h3ij3F3cPIk)]
|
7 |
+
|
8 |
+
<div align="center">
|
9 |
+
<img width="100%" alt="DINO illustration" src=".github/dino.gif">
|
10 |
+
</div>
|
11 |
+
|
12 |
+
## Pretrained models
|
13 |
+
You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both student and teacher networks. We also provide the backbone in `onnx` format, as well as detailed arguments and training/evaluation logs. Note that `DeiT-S` and `ViT-S` names refer exactly to the same architecture.
|
14 |
+
|
15 |
+
<table>
|
16 |
+
<tr>
|
17 |
+
<th>arch</th>
|
18 |
+
<th>params</th>
|
19 |
+
<th>k-nn</th>
|
20 |
+
<th>linear</th>
|
21 |
+
<th colspan="6">download</th>
|
22 |
+
</tr>
|
23 |
+
<tr>
|
24 |
+
<td>ViT-S/16</td>
|
25 |
+
<td>21M</td>
|
26 |
+
<td>74.5%</td>
|
27 |
+
<td>77.0%</td>
|
28 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth">backbone only</a></td>
|
29 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain_full_checkpoint.pth">full ckpt</a></td>
|
30 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deits16.onnx">onnx</a></td>
|
31 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/args.txt">args</a></td>
|
32 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain_log.txt">logs</a></td>
|
33 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain_eval_linear_log.txt">eval logs</a></td>
|
34 |
+
</tr>
|
35 |
+
<tr>
|
36 |
+
<td>ViT-S/8</td>
|
37 |
+
<td>21M</td>
|
38 |
+
<td>78.3%</td>
|
39 |
+
<td>79.7%</td>
|
40 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth">backbone only</a></td>
|
41 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain_full_checkpoint.pth">full ckpt</a></td>
|
42 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deits8.onnx">onnx</a></td>
|
43 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/args.txt">args</a></td>
|
44 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain_log.txt">logs</a></td>
|
45 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain_eval_linear_log.txt">eval logs</a></td>
|
46 |
+
</tr>
|
47 |
+
<tr>
|
48 |
+
<td>ViT-B/16</td>
|
49 |
+
<td>85M</td>
|
50 |
+
<td>76.1%</td>
|
51 |
+
<td>78.2%</td>
|
52 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth">backbone only</a></td>
|
53 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_full_checkpoint.pth">full ckpt</a></td>
|
54 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitb16.onnx">onnx</a></td>
|
55 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/args.txt">args</a></td>
|
56 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_log.txt">logs</a></td>
|
57 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_eval_linear_log.txt">eval logs</a></td>
|
58 |
+
</tr>
|
59 |
+
<tr>
|
60 |
+
<td>ViT-B/8</td>
|
61 |
+
<td>85M</td>
|
62 |
+
<td>77.4%</td>
|
63 |
+
<td>80.1%</td>
|
64 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth">backbone only</a></td>
|
65 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain_full_checkpoint.pth">full ckpt</a></td>
|
66 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitb8.onnx">onnx</a></td>
|
67 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/args.txt">args</a></td>
|
68 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain_log.txt">logs</a></td>
|
69 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain_eval_linear_log.txt">eval logs</a></td>
|
70 |
+
</tr>
|
71 |
+
<tr>
|
72 |
+
<td>ResNet-50</td>
|
73 |
+
<td>23M</td>
|
74 |
+
<td>67.5%</td>
|
75 |
+
<td>75.3%</td>
|
76 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth">backbone only</a></td>
|
77 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain_full_checkpoint.pth">full ckpt</a></td>
|
78 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50.onnx">onnx</a></td>
|
79 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/args.txt">args</a></td>
|
80 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain_log.txt">logs</a></td>
|
81 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain_eval_linear_log.txt">eval logs</a></td>
|
82 |
+
</tr>
|
83 |
+
</table>
|
84 |
+
|
85 |
+
We also release XCiT models ([[`arXiv`](https://arxiv.org/abs/2106.09681)] [[`code`](https://github.com/facebookresearch/xcit)]) trained with DINO:
|
86 |
+
<table>
|
87 |
+
<tr>
|
88 |
+
<th>arch</th>
|
89 |
+
<th>params</th>
|
90 |
+
<th>k-nn</th>
|
91 |
+
<th>linear</th>
|
92 |
+
<th colspan="5">download</th>
|
93 |
+
</tr>
|
94 |
+
<tr>
|
95 |
+
<td>xcit_small_12_p16</td>
|
96 |
+
<td>26M</td>
|
97 |
+
<td>76.0%</td>
|
98 |
+
<td>77.8%</td>
|
99 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth">backbone only</a></td>
|
100 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain_full_checkpoint.pth">full ckpt</a></td>
|
101 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/args.txt">args</a></td>
|
102 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain_log.txt">logs</a></td>
|
103 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain_eval_linear_log.txt">eval</a></td>
|
104 |
+
</tr>
|
105 |
+
<tr>
|
106 |
+
<td>xcit_small_12_p8</td>
|
107 |
+
<td>26M</td>
|
108 |
+
<td>77.1%</td>
|
109 |
+
<td>79.2%</td>
|
110 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth">backbone only</a></td>
|
111 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain_full_checkpoint.pth">full ckpt</a></td>
|
112 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/args.txt">args</a></td>
|
113 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain_log.txt">logs</a></td>
|
114 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain_eval_linear_log.txt">eval</a></td>
|
115 |
+
</tr>
|
116 |
+
<tr>
|
117 |
+
<td>xcit_medium_24_p16</td>
|
118 |
+
<td>84M</td>
|
119 |
+
<td>76.4%</td>
|
120 |
+
<td>78.8%</td>
|
121 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth">backbone only</a></td>
|
122 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain_full_checkpoint.pth">full ckpt</a></td>
|
123 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/args.txt">args</a></td>
|
124 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain_log.txt">logs</a></td>
|
125 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain_eval_linear_log.txt">eval</a></td>
|
126 |
+
</tr>
|
127 |
+
<tr>
|
128 |
+
<td>xcit_medium_24_p8</td>
|
129 |
+
<td>84M</td>
|
130 |
+
<td>77.9%</td>
|
131 |
+
<td>80.3%</td>
|
132 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth">backbone only</a></td>
|
133 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain_full_checkpoint.pth">full ckpt</a></td>
|
134 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/args.txt">args</a></td>
|
135 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain_log.txt">logs</a></td>
|
136 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain_eval_linear_log.txt">eval</a></td>
|
137 |
+
</tr>
|
138 |
+
</table>
|
139 |
+
|
140 |
+
### Pretrained models on PyTorch Hub
|
141 |
+
```python
|
142 |
+
import torch
|
143 |
+
vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
|
144 |
+
vits8 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')
|
145 |
+
vitb16 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
|
146 |
+
vitb8 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')
|
147 |
+
xcit_small_12_p16 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_small_12_p16')
|
148 |
+
xcit_small_12_p8 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_small_12_p8')
|
149 |
+
xcit_medium_24_p16 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p16')
|
150 |
+
xcit_medium_24_p8 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p8')
|
151 |
+
resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
|
152 |
+
```
|
153 |
+
|
154 |
+
## Training
|
155 |
+
|
156 |
+
### Documentation
|
157 |
+
Please install [PyTorch](https://pytorch.org/) and download the [ImageNet](https://imagenet.stanford.edu/) dataset. This codebase has been developed with python version 3.6, PyTorch version 1.7.1, CUDA 11.0 and torchvision 0.8.2. The exact arguments to reproduce the models presented in our paper can be found in the `args` column of the [pretrained models section](https://github.com/facebookresearch/dino#pretrained-models). For a glimpse at the full documentation of DINO training please run:
|
158 |
+
```
|
159 |
+
python main_dino.py --help
|
160 |
+
```
|
161 |
+
|
162 |
+
### Vanilla DINO training :sauropod:
|
163 |
+
Run DINO with ViT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training time is 1.75 day and the resulting checkpoint should reach 69.3% on k-NN eval and 74.0% on linear eval. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_log.txt) and [linear evaluation](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_eval.txt) logs (with batch size 256 at evaluation time) for this run to help reproducibility.
|
164 |
+
```
|
165 |
+
python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
|
166 |
+
```
|
167 |
+
|
168 |
+
### Multi-node training
|
169 |
+
We use Slurm and [submitit](https://github.com/facebookincubator/submitit) (`pip install submitit`). To train on 2 nodes with 8 GPUs each (total 16 GPUs):
|
170 |
+
```
|
171 |
+
python run_with_submitit.py --nodes 2 --ngpus 8 --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
|
172 |
+
```
|
173 |
+
|
174 |
+
<details>
|
175 |
+
<summary>
|
176 |
+
DINO with ViT-base network.
|
177 |
+
</summary>
|
178 |
+
|
179 |
+
```
|
180 |
+
python run_with_submitit.py --nodes 2 --ngpus 8 --use_volta32 --arch vit_base --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
|
181 |
+
```
|
182 |
+
|
183 |
+
</details>
|
184 |
+
|
185 |
+
### Boosting DINO performance :t-rex:
|
186 |
+
You can improve the performance of the vanilla run by:
|
187 |
+
- training for more epochs: `--epochs 300`,
|
188 |
+
- increasing the teacher temperature: `--teacher_temp 0.07 --warmup_teacher_temp_epochs 30`.
|
189 |
+
- removing last layer normalization (only safe with `--arch vit_small`): `--norm_last_layer false`,
|
190 |
+
|
191 |
+
<details>
|
192 |
+
<summary>
|
193 |
+
Full command.
|
194 |
+
</summary>
|
195 |
+
|
196 |
+
```
|
197 |
+
python run_with_submitit.py --arch vit_small --epochs 300 --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --norm_last_layer false --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
|
198 |
+
```
|
199 |
+
|
200 |
+
</details>
|
201 |
+
|
202 |
+
The resulting pretrained model should reach 73.3% on k-NN eval and 76.0% on linear eval. Training time is 2.6 days with 16 GPUs. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_boost_deitsmall16_log.txt) and [linear evaluation](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_boost_deitsmall16_eval.txt) logs (with batch size 256 at evaluation time) for this run to help reproducibility.
|
203 |
+
|
204 |
+
### ResNet-50 and other convnets trainings
|
205 |
+
This code also works for training DINO on convolutional networks, like ResNet-50 for example. We highly recommend to adapt some optimization arguments in this case. For example following is a command to train DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs. We provide [training logs](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_rn50_log.txt) and [final checkpoint](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_rn50_checkpoint.pth) for this run.
|
206 |
+
```
|
207 |
+
python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
|
208 |
+
```
|
209 |
+
|
210 |
+
## Self-attention visualization
|
211 |
+
You can look at the self-attention of the [CLS] token on the different heads of the last layer by running:
|
212 |
+
```
|
213 |
+
python visualize_attention.py
|
214 |
+
```
|
215 |
+
|
216 |
+
<div align="center">
|
217 |
+
<img width="100%" alt="Self-attention from a Vision Transformer with 8x8 patches trained with DINO" src=".github/attention_maps.png">
|
218 |
+
</div>
|
219 |
+
|
220 |
+
## Self-attention video generation
|
221 |
+
You can generate videos like the one on the blog post with `video_generation.py`.
|
222 |
+
|
223 |
+
https://user-images.githubusercontent.com/46140458/116817761-47885e80-ab68-11eb-9975-d61d5a919e13.mp4
|
224 |
+
|
225 |
+
Extract frames from input video and generate attention video:
|
226 |
+
```
|
227 |
+
python video_generation.py --pretrained_weights dino_deitsmall8_pretrain.pth \
|
228 |
+
--input_path input/video.mp4 \
|
229 |
+
--output_path output/ \
|
230 |
+
--fps 25
|
231 |
+
```
|
232 |
+
|
233 |
+
Use folder of frames already extracted and generate attention video:
|
234 |
+
```
|
235 |
+
python video_generation.py --pretrained_weights dino_deitsmall8_pretrain.pth \
|
236 |
+
--input_path output/frames/ \
|
237 |
+
--output_path output/ \
|
238 |
+
--resize 256 \
|
239 |
+
```
|
240 |
+
|
241 |
+
Only generate video from folder of attention maps images:
|
242 |
+
```
|
243 |
+
python video_generation.py --input_path output/attention \
|
244 |
+
--output_path output/ \
|
245 |
+
--video_only \
|
246 |
+
--video_format avi
|
247 |
+
```
|
248 |
+
|
249 |
+
|
250 |
+
## Evaluation: k-NN classification on ImageNet
|
251 |
+
To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run:
|
252 |
+
```
|
253 |
+
python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --data_path /path/to/imagenet
|
254 |
+
```
|
255 |
+
If you choose not to specify `--pretrained_weights`, then DINO reference weights are used by default. If you want instead to evaluate checkpoints from a run of your own, you can run for example:
|
256 |
+
```
|
257 |
+
python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --pretrained_weights /path/to/checkpoint.pth --checkpoint_key teacher --data_path /path/to/imagenet
|
258 |
+
```
|
259 |
+
|
260 |
+
## Evaluation: Linear classification on ImageNet
|
261 |
+
To train a supervised linear classifier on frozen weights on a single node with 8 gpus, run:
|
262 |
+
```
|
263 |
+
python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py --data_path /path/to/imagenet
|
264 |
+
```
|
265 |
+
|
266 |
+
We release the logs and weights from evaluating the different models:
|
267 |
+
|
268 |
+
<table>
|
269 |
+
<tr>
|
270 |
+
<th>arch</th>
|
271 |
+
<th>top-1 ImageNet</th>
|
272 |
+
<th colspan="2">linear evaluation</th>
|
273 |
+
</tr>
|
274 |
+
<tr>
|
275 |
+
<td>ViT-S/16</td>
|
276 |
+
<td>77.0%</td>
|
277 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth">linear weights</a></td>
|
278 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain_eval_linear_log.txt">logs</a></td>
|
279 |
+
</tr>
|
280 |
+
<tr>
|
281 |
+
<td>ViT-S/8</td>
|
282 |
+
<td>79.7%</td>
|
283 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth">linear weights</a></td>
|
284 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain_eval_linear_log.txt">logs</a></td>
|
285 |
+
</tr>
|
286 |
+
<tr>
|
287 |
+
<td>ViT-B/16</td>
|
288 |
+
<td>78.2%</td>
|
289 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth">linear weights</a></td>
|
290 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_eval_linear_log.txt">logs</a></td>
|
291 |
+
</tr>
|
292 |
+
<tr>
|
293 |
+
<td>ViT-B/8</td>
|
294 |
+
<td>80.1%</td>
|
295 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth">linear weights</a></td>
|
296 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain_eval_linear_log.txt">logs</a></td>
|
297 |
+
</tr>
|
298 |
+
<tr>
|
299 |
+
<td>xcit_small_12_p16</td>
|
300 |
+
<td>77.8%</td>
|
301 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_linearweights.pth">linear weights</a></td>
|
302 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain_eval_linear_log.txt">logs</a></td>
|
303 |
+
</tr>
|
304 |
+
<tr>
|
305 |
+
<td>xcit_small_12_p8</td>
|
306 |
+
<td>79.2%</td>
|
307 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_linearweights.pth">linear weights</a></td>
|
308 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain_eval_linear_log.txt">logs</a></td>
|
309 |
+
</tr>
|
310 |
+
<tr>
|
311 |
+
<td>xcit_medium_24_p16</td>
|
312 |
+
<td>78.8%</td>
|
313 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_linearweights.pth">linear weights</a></td>
|
314 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain_eval_linear_log.txt">logs</a></td>
|
315 |
+
</tr>
|
316 |
+
<tr>
|
317 |
+
<td>xcit_medium_24_p8</td>
|
318 |
+
<td>80.3%</td>
|
319 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_linearweights.pth">linear weights</a></td>
|
320 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain_eval_linear_log.txt">logs</a></td>
|
321 |
+
</tr>
|
322 |
+
<tr>
|
323 |
+
<td>ResNet-50</td>
|
324 |
+
<td>75.3%</td>
|
325 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_linearweights.pth">linear weights</a></td>
|
326 |
+
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain_eval_linear_log.txt">logs</a></td>
|
327 |
+
</tr>
|
328 |
+
</table>
|
329 |
+
|
330 |
+
You can check the performance of the pretrained weights on ImageNet validation set by running the following command lines:
|
331 |
+
```
|
332 |
+
python eval_linear.py --evaluate --arch vit_small --patch_size 16 --data_path /path/to/imagenet/train
|
333 |
+
```
|
334 |
+
|
335 |
+
```
|
336 |
+
python eval_linear.py --evaluate --arch vit_small --patch_size 8 --data_path /path/to/imagenet/train
|
337 |
+
```
|
338 |
+
|
339 |
+
```
|
340 |
+
python eval_linear.py --evaluate --arch vit_base --patch_size 16 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet/train
|
341 |
+
```
|
342 |
+
|
343 |
+
```
|
344 |
+
python eval_linear.py --evaluate --arch vit_base --patch_size 8 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet/train
|
345 |
+
```
|
346 |
+
|
347 |
+
```
|
348 |
+
python eval_linear.py --evaluate --arch resnet50 --data_path /path/to/imagenet/train
|
349 |
+
```
|
350 |
+
|
351 |
+
## Evaluation: DAVIS 2017 Video object segmentation
|
352 |
+
Please verify that you're using pytorch version 1.7.1 since we are not able to reproduce the results with most recent pytorch 1.8.1 at the moment.
|
353 |
+
|
354 |
+
**Step 1: Prepare DAVIS 2017 data**
|
355 |
+
```
|
356 |
+
cd $HOME
|
357 |
+
git clone https://github.com/davisvideochallenge/davis-2017 && cd davis-2017
|
358 |
+
./data/get_davis.sh
|
359 |
+
```
|
360 |
+
|
361 |
+
**Step 2: Video object segmentation**
|
362 |
+
```
|
363 |
+
python eval_video_segmentation.py --data_path $HOME/davis-2017/DAVIS/ --output_dir /path/to/saving_dir
|
364 |
+
```
|
365 |
+
|
366 |
+
**Step 3: Evaluate the obtained segmentation**
|
367 |
+
```
|
368 |
+
git clone https://github.com/davisvideochallenge/davis2017-evaluation $HOME/davis2017-evaluation
|
369 |
+
python $HOME/davis2017-evaluation/evaluation_method.py --task semi-supervised --results_path /path/to/saving_dir --davis_path $HOME/davis-2017/DAVIS/
|
370 |
+
```
|
371 |
+
|
372 |
+
## Evaluation: Image Retrieval on revisited Oxford and Paris
|
373 |
+
Step 1: Prepare revisited Oxford and Paris by following [this repo](https://github.com/filipradenovic/revisitop).
|
374 |
+
|
375 |
+
Step 2: Image retrieval (if you do not specify weights with `--pretrained_weights` then by default [DINO weights pretrained on Google Landmark v2 dataset](https://dl.fbaipublicfiles.com/dino/dino_vitsmall16_googlelandmark_pretrain/dino_vitsmall16_googlelandmark_pretrain.pth) will be used).
|
376 |
+
|
377 |
+
Paris:
|
378 |
+
```
|
379 |
+
python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_image_retrieval.py --imsize 512 --multiscale 1 --data_path /path/to/revisited_paris_oxford/ --dataset rparis6k
|
380 |
+
```
|
381 |
+
|
382 |
+
Oxford:
|
383 |
+
```
|
384 |
+
python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_image_retrieval.py --imsize 224 --multiscale 0 --data_path /path/to/revisited_paris_oxford/ --dataset roxford5k
|
385 |
+
```
|
386 |
+
|
387 |
+
## Evaluation: Copy detection on Copydays
|
388 |
+
Step 1: Prepare [Copydays dataset](https://lear.inrialpes.fr/~jegou/data.php#copydays).
|
389 |
+
|
390 |
+
Step 2 (opt): Prepare a set of image distractors and a set of images on which to learn the whitening operator.
|
391 |
+
In our paper, we use 10k random images from YFCC100M as distractors and 20k random images from YFCC100M (different from the distractors) for computing the whitening operation.
|
392 |
+
|
393 |
+
Step 3: Run copy detection:
|
394 |
+
```
|
395 |
+
python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_copy_detection.py --data_path /path/to/copydays/ --whitening_path /path/to/whitening_data/ --distractors_path /path/to/distractors/
|
396 |
+
```
|
397 |
+
We report result on the strong subset. For example in the stdout from the command above we get: `eval on strong mAP=0.858`.
|
398 |
+
|
399 |
+
## License
|
400 |
+
This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file.
|
401 |
+
|
402 |
+
## Citation
|
403 |
+
If you find this repository useful, please consider giving a star :star: and citation :t-rex::
|
404 |
+
```
|
405 |
+
@inproceedings{caron2021emerging,
|
406 |
+
title={Emerging Properties in Self-Supervised Vision Transformers},
|
407 |
+
author={Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
|
408 |
+
booktitle={Proceedings of the International Conference on Computer Vision (ICCV)},
|
409 |
+
year={2021}
|
410 |
+
}
|
411 |
+
```
|
dino/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from os.path import dirname, join
|
3 |
+
|
4 |
+
sys.path.insert(0, join(dirname(__file__), "."))
|
dino/eval_copy_detection.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
import pickle
|
17 |
+
import argparse
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch import nn
|
21 |
+
import torch.distributed as dist
|
22 |
+
import torch.backends.cudnn as cudnn
|
23 |
+
from torchvision import models as torchvision_models
|
24 |
+
from torchvision import transforms as pth_transforms
|
25 |
+
from PIL import Image, ImageFile
|
26 |
+
import numpy as np
|
27 |
+
|
28 |
+
import utils
|
29 |
+
import vision_transformer as vits
|
30 |
+
from eval_knn import extract_features
|
31 |
+
|
32 |
+
|
33 |
+
class CopydaysDataset:
|
34 |
+
def __init__(self, basedir):
|
35 |
+
self.basedir = basedir
|
36 |
+
self.block_names = (
|
37 |
+
["original", "strong"]
|
38 |
+
+ ["jpegqual/%d" % i for i in [3, 5, 8, 10, 15, 20, 30, 50, 75]]
|
39 |
+
+ ["crops/%d" % i for i in [10, 15, 20, 30, 40, 50, 60, 70, 80]]
|
40 |
+
)
|
41 |
+
self.nblocks = len(self.block_names)
|
42 |
+
|
43 |
+
self.query_blocks = range(self.nblocks)
|
44 |
+
self.q_block_sizes = np.ones(self.nblocks, dtype=int) * 157
|
45 |
+
self.q_block_sizes[1] = 229
|
46 |
+
# search only among originals
|
47 |
+
self.database_blocks = [0]
|
48 |
+
|
49 |
+
def get_block(self, i):
|
50 |
+
dirname = self.basedir + "/" + self.block_names[i]
|
51 |
+
fnames = [
|
52 |
+
dirname + "/" + fname
|
53 |
+
for fname in sorted(os.listdir(dirname))
|
54 |
+
if fname.endswith(".jpg")
|
55 |
+
]
|
56 |
+
return fnames
|
57 |
+
|
58 |
+
def get_block_filenames(self, subdir_name):
|
59 |
+
dirname = self.basedir + "/" + subdir_name
|
60 |
+
return [
|
61 |
+
fname for fname in sorted(os.listdir(dirname)) if fname.endswith(".jpg")
|
62 |
+
]
|
63 |
+
|
64 |
+
def eval_result(self, ids, distances):
|
65 |
+
j0 = 0
|
66 |
+
for i in range(self.nblocks):
|
67 |
+
j1 = j0 + self.q_block_sizes[i]
|
68 |
+
block_name = self.block_names[i]
|
69 |
+
I = ids[j0:j1] # block size
|
70 |
+
sum_AP = 0
|
71 |
+
if block_name != "strong":
|
72 |
+
# 1:1 mapping of files to names
|
73 |
+
positives_per_query = [[i] for i in range(j1 - j0)]
|
74 |
+
else:
|
75 |
+
originals = self.get_block_filenames("original")
|
76 |
+
strongs = self.get_block_filenames("strong")
|
77 |
+
|
78 |
+
# check if prefixes match
|
79 |
+
positives_per_query = [
|
80 |
+
[j for j, bname in enumerate(originals) if bname[:4] == qname[:4]]
|
81 |
+
for qname in strongs
|
82 |
+
]
|
83 |
+
|
84 |
+
for qno, Iline in enumerate(I):
|
85 |
+
positives = positives_per_query[qno]
|
86 |
+
ranks = []
|
87 |
+
for rank, bno in enumerate(Iline):
|
88 |
+
if bno in positives:
|
89 |
+
ranks.append(rank)
|
90 |
+
sum_AP += score_ap_from_ranks_1(ranks, len(positives))
|
91 |
+
|
92 |
+
print("eval on %s mAP=%.3f" % (block_name, sum_AP / (j1 - j0)))
|
93 |
+
j0 = j1
|
94 |
+
|
95 |
+
|
96 |
+
# from the Holidays evaluation package
|
97 |
+
def score_ap_from_ranks_1(ranks, nres):
|
98 |
+
"""Compute the average precision of one search.
|
99 |
+
ranks = ordered list of ranks of true positives
|
100 |
+
nres = total number of positives in dataset
|
101 |
+
"""
|
102 |
+
|
103 |
+
# accumulate trapezoids in PR-plot
|
104 |
+
ap = 0.0
|
105 |
+
|
106 |
+
# All have an x-size of:
|
107 |
+
recall_step = 1.0 / nres
|
108 |
+
|
109 |
+
for ntp, rank in enumerate(ranks):
|
110 |
+
|
111 |
+
# y-size on left side of trapezoid:
|
112 |
+
# ntp = nb of true positives so far
|
113 |
+
# rank = nb of retrieved items so far
|
114 |
+
if rank == 0:
|
115 |
+
precision_0 = 1.0
|
116 |
+
else:
|
117 |
+
precision_0 = ntp / float(rank)
|
118 |
+
|
119 |
+
# y-size on right side of trapezoid:
|
120 |
+
# ntp and rank are increased by one
|
121 |
+
precision_1 = (ntp + 1) / float(rank + 1)
|
122 |
+
|
123 |
+
ap += (precision_1 + precision_0) * recall_step / 2.0
|
124 |
+
|
125 |
+
return ap
|
126 |
+
|
127 |
+
|
128 |
+
class ImgListDataset(torch.utils.data.Dataset):
|
129 |
+
def __init__(self, img_list, transform=None):
|
130 |
+
self.samples = img_list
|
131 |
+
self.transform = transform
|
132 |
+
|
133 |
+
def __getitem__(self, i):
|
134 |
+
with open(self.samples[i], "rb") as f:
|
135 |
+
img = Image.open(f)
|
136 |
+
img = img.convert("RGB")
|
137 |
+
if self.transform is not None:
|
138 |
+
img = self.transform(img)
|
139 |
+
return img, i
|
140 |
+
|
141 |
+
def __len__(self):
|
142 |
+
return len(self.samples)
|
143 |
+
|
144 |
+
|
145 |
+
def is_image_file(s):
|
146 |
+
ext = s.split(".")[-1]
|
147 |
+
if ext in ["jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"]:
|
148 |
+
return True
|
149 |
+
return False
|
150 |
+
|
151 |
+
|
152 |
+
@torch.no_grad()
|
153 |
+
def extract_features(image_list, model, args):
|
154 |
+
transform = pth_transforms.Compose(
|
155 |
+
[
|
156 |
+
pth_transforms.Resize((args.imsize, args.imsize), interpolation=3),
|
157 |
+
pth_transforms.ToTensor(),
|
158 |
+
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
159 |
+
]
|
160 |
+
)
|
161 |
+
tempdataset = ImgListDataset(image_list, transform=transform)
|
162 |
+
data_loader = torch.utils.data.DataLoader(
|
163 |
+
tempdataset,
|
164 |
+
batch_size=args.batch_size_per_gpu,
|
165 |
+
num_workers=args.num_workers,
|
166 |
+
drop_last=False,
|
167 |
+
sampler=torch.utils.data.DistributedSampler(tempdataset, shuffle=False),
|
168 |
+
)
|
169 |
+
features = None
|
170 |
+
for samples, index in utils.MetricLogger(delimiter=" ").log_every(data_loader, 10):
|
171 |
+
samples, index = samples.cuda(non_blocking=True), index.cuda(non_blocking=True)
|
172 |
+
feats = model.get_intermediate_layers(samples, n=1)[0].clone()
|
173 |
+
|
174 |
+
cls_output_token = feats[:, 0, :] # [CLS] token
|
175 |
+
# GeM with exponent 4 for output patch tokens
|
176 |
+
b, h, w, d = (
|
177 |
+
len(samples),
|
178 |
+
int(samples.shape[-2] / model.patch_embed.patch_size),
|
179 |
+
int(samples.shape[-1] / model.patch_embed.patch_size),
|
180 |
+
feats.shape[-1],
|
181 |
+
)
|
182 |
+
feats = feats[:, 1:, :].reshape(b, h, w, d)
|
183 |
+
feats = feats.clamp(min=1e-6).permute(0, 3, 1, 2)
|
184 |
+
feats = (
|
185 |
+
nn.functional.avg_pool2d(feats.pow(4), (h, w)).pow(1.0 / 4).reshape(b, -1)
|
186 |
+
)
|
187 |
+
# concatenate [CLS] token and GeM pooled patch tokens
|
188 |
+
feats = torch.cat((cls_output_token, feats), dim=1)
|
189 |
+
|
190 |
+
# init storage feature matrix
|
191 |
+
if dist.get_rank() == 0 and features is None:
|
192 |
+
features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
|
193 |
+
if args.use_cuda:
|
194 |
+
features = features.cuda(non_blocking=True)
|
195 |
+
|
196 |
+
# get indexes from all processes
|
197 |
+
y_all = torch.empty(
|
198 |
+
dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device
|
199 |
+
)
|
200 |
+
y_l = list(y_all.unbind(0))
|
201 |
+
y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
|
202 |
+
y_all_reduce.wait()
|
203 |
+
index_all = torch.cat(y_l)
|
204 |
+
|
205 |
+
# share features between processes
|
206 |
+
feats_all = torch.empty(
|
207 |
+
dist.get_world_size(),
|
208 |
+
feats.size(0),
|
209 |
+
feats.size(1),
|
210 |
+
dtype=feats.dtype,
|
211 |
+
device=feats.device,
|
212 |
+
)
|
213 |
+
output_l = list(feats_all.unbind(0))
|
214 |
+
output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
|
215 |
+
output_all_reduce.wait()
|
216 |
+
|
217 |
+
# update storage feature matrix
|
218 |
+
if dist.get_rank() == 0:
|
219 |
+
if args.use_cuda:
|
220 |
+
features.index_copy_(0, index_all, torch.cat(output_l))
|
221 |
+
else:
|
222 |
+
features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
|
223 |
+
return features # features is still None for every rank which is not 0 (main)
|
224 |
+
|
225 |
+
|
226 |
+
if __name__ == "__main__":
|
227 |
+
parser = argparse.ArgumentParser("Copy detection on Copydays")
|
228 |
+
parser.add_argument(
|
229 |
+
"--data_path",
|
230 |
+
default="/path/to/copydays/",
|
231 |
+
type=str,
|
232 |
+
help="See https://lear.inrialpes.fr/~jegou/data.php#copydays",
|
233 |
+
)
|
234 |
+
parser.add_argument(
|
235 |
+
"--whitening_path",
|
236 |
+
default="/path/to/whitening_data/",
|
237 |
+
type=str,
|
238 |
+
help="""Path to directory with images used for computing the whitening operator.
|
239 |
+
In our paper, we use 20k random images from YFCC100M.""",
|
240 |
+
)
|
241 |
+
parser.add_argument(
|
242 |
+
"--distractors_path",
|
243 |
+
default="/path/to/distractors/",
|
244 |
+
type=str,
|
245 |
+
help="Path to directory with distractors images. In our paper, we use 10k random images from YFCC100M.",
|
246 |
+
)
|
247 |
+
parser.add_argument(
|
248 |
+
"--imsize", default=320, type=int, help="Image size (square image)"
|
249 |
+
)
|
250 |
+
parser.add_argument(
|
251 |
+
"--batch_size_per_gpu", default=16, type=int, help="Per-GPU batch-size"
|
252 |
+
)
|
253 |
+
parser.add_argument(
|
254 |
+
"--pretrained_weights",
|
255 |
+
default="",
|
256 |
+
type=str,
|
257 |
+
help="Path to pretrained weights to evaluate.",
|
258 |
+
)
|
259 |
+
parser.add_argument("--use_cuda", default=True, type=utils.bool_flag)
|
260 |
+
parser.add_argument("--arch", default="vit_base", type=str, help="Architecture")
|
261 |
+
parser.add_argument(
|
262 |
+
"--patch_size", default=8, type=int, help="Patch resolution of the model."
|
263 |
+
)
|
264 |
+
parser.add_argument(
|
265 |
+
"--checkpoint_key",
|
266 |
+
default="teacher",
|
267 |
+
type=str,
|
268 |
+
help='Key to use in the checkpoint (example: "teacher")',
|
269 |
+
)
|
270 |
+
parser.add_argument(
|
271 |
+
"--num_workers",
|
272 |
+
default=10,
|
273 |
+
type=int,
|
274 |
+
help="Number of data loading workers per GPU.",
|
275 |
+
)
|
276 |
+
parser.add_argument(
|
277 |
+
"--dist_url",
|
278 |
+
default="env://",
|
279 |
+
type=str,
|
280 |
+
help="""url used to set up
|
281 |
+
distributed training; see https://pytorch.org/docs/stable/distributed.html""",
|
282 |
+
)
|
283 |
+
parser.add_argument(
|
284 |
+
"--local_rank",
|
285 |
+
default=0,
|
286 |
+
type=int,
|
287 |
+
help="Please ignore and do not set this argument.",
|
288 |
+
)
|
289 |
+
args = parser.parse_args()
|
290 |
+
|
291 |
+
utils.init_distributed_mode(args)
|
292 |
+
print("git:\n {}\n".format(utils.get_sha()))
|
293 |
+
print(
|
294 |
+
"\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
|
295 |
+
)
|
296 |
+
cudnn.benchmark = True
|
297 |
+
|
298 |
+
# ============ building network ... ============
|
299 |
+
if "vit" in args.arch:
|
300 |
+
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
|
301 |
+
print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
|
302 |
+
else:
|
303 |
+
print(f"Architecture {args.arch} non supported")
|
304 |
+
sys.exit(1)
|
305 |
+
if args.use_cuda:
|
306 |
+
model.cuda()
|
307 |
+
model.eval()
|
308 |
+
utils.load_pretrained_weights(
|
309 |
+
model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size
|
310 |
+
)
|
311 |
+
|
312 |
+
dataset = CopydaysDataset(args.data_path)
|
313 |
+
|
314 |
+
# ============ Extract features ... ============
|
315 |
+
# extract features for queries
|
316 |
+
queries = []
|
317 |
+
for q in dataset.query_blocks:
|
318 |
+
queries.append(extract_features(dataset.get_block(q), model, args))
|
319 |
+
if utils.get_rank() == 0:
|
320 |
+
queries = torch.cat(queries)
|
321 |
+
print(f"Extraction of queries features done. Shape: {queries.shape}")
|
322 |
+
|
323 |
+
# extract features for database
|
324 |
+
database = []
|
325 |
+
for b in dataset.database_blocks:
|
326 |
+
database.append(extract_features(dataset.get_block(b), model, args))
|
327 |
+
|
328 |
+
# extract features for distractors
|
329 |
+
if os.path.isdir(args.distractors_path):
|
330 |
+
print("Using distractors...")
|
331 |
+
list_distractors = [
|
332 |
+
os.path.join(args.distractors_path, s)
|
333 |
+
for s in os.listdir(args.distractors_path)
|
334 |
+
if is_image_file(s)
|
335 |
+
]
|
336 |
+
database.append(extract_features(list_distractors, model, args))
|
337 |
+
if utils.get_rank() == 0:
|
338 |
+
database = torch.cat(database)
|
339 |
+
print(
|
340 |
+
f"Extraction of database and distractors features done. Shape: {database.shape}"
|
341 |
+
)
|
342 |
+
|
343 |
+
# ============ Whitening ... ============
|
344 |
+
if os.path.isdir(args.whitening_path):
|
345 |
+
print(
|
346 |
+
f"Extracting features on images from {args.whitening_path} for learning the whitening operator."
|
347 |
+
)
|
348 |
+
list_whit = [
|
349 |
+
os.path.join(args.whitening_path, s)
|
350 |
+
for s in os.listdir(args.whitening_path)
|
351 |
+
if is_image_file(s)
|
352 |
+
]
|
353 |
+
features_for_whitening = extract_features(list_whit, model, args)
|
354 |
+
if utils.get_rank() == 0:
|
355 |
+
# center
|
356 |
+
mean_feature = torch.mean(features_for_whitening, dim=0)
|
357 |
+
database -= mean_feature
|
358 |
+
queries -= mean_feature
|
359 |
+
pca = utils.PCA(dim=database.shape[-1], whit=0.5)
|
360 |
+
# compute covariance
|
361 |
+
cov = (
|
362 |
+
torch.mm(features_for_whitening.T, features_for_whitening)
|
363 |
+
/ features_for_whitening.shape[0]
|
364 |
+
)
|
365 |
+
pca.train_pca(cov.cpu().numpy())
|
366 |
+
database = pca.apply(database)
|
367 |
+
queries = pca.apply(queries)
|
368 |
+
|
369 |
+
# ============ Copy detection ... ============
|
370 |
+
if utils.get_rank() == 0:
|
371 |
+
# l2 normalize the features
|
372 |
+
database = nn.functional.normalize(database, dim=1, p=2)
|
373 |
+
queries = nn.functional.normalize(queries, dim=1, p=2)
|
374 |
+
|
375 |
+
# similarity
|
376 |
+
similarity = torch.mm(queries, database.T)
|
377 |
+
distances, indices = similarity.topk(20, largest=True, sorted=True)
|
378 |
+
|
379 |
+
# evaluate
|
380 |
+
retrieved = dataset.eval_result(indices, distances)
|
381 |
+
dist.barrier()
|
dino/eval_image_retrieval.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
import pickle
|
17 |
+
import argparse
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch import nn
|
21 |
+
import torch.distributed as dist
|
22 |
+
import torch.backends.cudnn as cudnn
|
23 |
+
from torchvision import models as torchvision_models
|
24 |
+
from torchvision import transforms as pth_transforms
|
25 |
+
from PIL import Image, ImageFile
|
26 |
+
import numpy as np
|
27 |
+
|
28 |
+
import utils
|
29 |
+
import vision_transformer as vits
|
30 |
+
from eval_knn import extract_features
|
31 |
+
|
32 |
+
|
33 |
+
class OxfordParisDataset(torch.utils.data.Dataset):
|
34 |
+
def __init__(self, dir_main, dataset, split, transform=None, imsize=None):
|
35 |
+
if dataset not in ["roxford5k", "rparis6k"]:
|
36 |
+
raise ValueError("Unknown dataset: {}!".format(dataset))
|
37 |
+
|
38 |
+
# loading imlist, qimlist, and gnd, in cfg as a dict
|
39 |
+
gnd_fname = os.path.join(dir_main, dataset, "gnd_{}.pkl".format(dataset))
|
40 |
+
with open(gnd_fname, "rb") as f:
|
41 |
+
cfg = pickle.load(f)
|
42 |
+
cfg["gnd_fname"] = gnd_fname
|
43 |
+
cfg["ext"] = ".jpg"
|
44 |
+
cfg["qext"] = ".jpg"
|
45 |
+
cfg["dir_data"] = os.path.join(dir_main, dataset)
|
46 |
+
cfg["dir_images"] = os.path.join(cfg["dir_data"], "jpg")
|
47 |
+
cfg["n"] = len(cfg["imlist"])
|
48 |
+
cfg["nq"] = len(cfg["qimlist"])
|
49 |
+
cfg["im_fname"] = config_imname
|
50 |
+
cfg["qim_fname"] = config_qimname
|
51 |
+
cfg["dataset"] = dataset
|
52 |
+
self.cfg = cfg
|
53 |
+
|
54 |
+
self.samples = cfg["qimlist"] if split == "query" else cfg["imlist"]
|
55 |
+
self.transform = transform
|
56 |
+
self.imsize = imsize
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
return len(self.samples)
|
60 |
+
|
61 |
+
def __getitem__(self, index):
|
62 |
+
path = os.path.join(self.cfg["dir_images"], self.samples[index] + ".jpg")
|
63 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
64 |
+
with open(path, "rb") as f:
|
65 |
+
img = Image.open(f)
|
66 |
+
img = img.convert("RGB")
|
67 |
+
if self.imsize is not None:
|
68 |
+
img.thumbnail((self.imsize, self.imsize), Image.ANTIALIAS)
|
69 |
+
if self.transform is not None:
|
70 |
+
img = self.transform(img)
|
71 |
+
return img, index
|
72 |
+
|
73 |
+
|
74 |
+
def config_imname(cfg, i):
|
75 |
+
return os.path.join(cfg["dir_images"], cfg["imlist"][i] + cfg["ext"])
|
76 |
+
|
77 |
+
|
78 |
+
def config_qimname(cfg, i):
|
79 |
+
return os.path.join(cfg["dir_images"], cfg["qimlist"][i] + cfg["qext"])
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
parser = argparse.ArgumentParser("Image Retrieval on revisited Paris and Oxford")
|
84 |
+
parser.add_argument(
|
85 |
+
"--data_path", default="/path/to/revisited_paris_oxford/", type=str
|
86 |
+
)
|
87 |
+
parser.add_argument(
|
88 |
+
"--dataset", default="roxford5k", type=str, choices=["roxford5k", "rparis6k"]
|
89 |
+
)
|
90 |
+
parser.add_argument("--multiscale", default=False, type=utils.bool_flag)
|
91 |
+
parser.add_argument("--imsize", default=224, type=int, help="Image size")
|
92 |
+
parser.add_argument(
|
93 |
+
"--pretrained_weights",
|
94 |
+
default="",
|
95 |
+
type=str,
|
96 |
+
help="Path to pretrained weights to evaluate.",
|
97 |
+
)
|
98 |
+
parser.add_argument("--use_cuda", default=True, type=utils.bool_flag)
|
99 |
+
parser.add_argument("--arch", default="vit_small", type=str, help="Architecture")
|
100 |
+
parser.add_argument(
|
101 |
+
"--patch_size", default=16, type=int, help="Patch resolution of the model."
|
102 |
+
)
|
103 |
+
parser.add_argument(
|
104 |
+
"--checkpoint_key",
|
105 |
+
default="teacher",
|
106 |
+
type=str,
|
107 |
+
help='Key to use in the checkpoint (example: "teacher")',
|
108 |
+
)
|
109 |
+
parser.add_argument(
|
110 |
+
"--num_workers",
|
111 |
+
default=10,
|
112 |
+
type=int,
|
113 |
+
help="Number of data loading workers per GPU.",
|
114 |
+
)
|
115 |
+
parser.add_argument(
|
116 |
+
"--dist_url",
|
117 |
+
default="env://",
|
118 |
+
type=str,
|
119 |
+
help="""url used to set up
|
120 |
+
distributed training; see https://pytorch.org/docs/stable/distributed.html""",
|
121 |
+
)
|
122 |
+
parser.add_argument(
|
123 |
+
"--local_rank",
|
124 |
+
default=0,
|
125 |
+
type=int,
|
126 |
+
help="Please ignore and do not set this argument.",
|
127 |
+
)
|
128 |
+
args = parser.parse_args()
|
129 |
+
|
130 |
+
utils.init_distributed_mode(args)
|
131 |
+
print("git:\n {}\n".format(utils.get_sha()))
|
132 |
+
print(
|
133 |
+
"\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
|
134 |
+
)
|
135 |
+
cudnn.benchmark = True
|
136 |
+
|
137 |
+
# ============ preparing data ... ============
|
138 |
+
transform = pth_transforms.Compose(
|
139 |
+
[
|
140 |
+
pth_transforms.ToTensor(),
|
141 |
+
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
142 |
+
]
|
143 |
+
)
|
144 |
+
dataset_train = OxfordParisDataset(
|
145 |
+
args.data_path,
|
146 |
+
args.dataset,
|
147 |
+
split="train",
|
148 |
+
transform=transform,
|
149 |
+
imsize=args.imsize,
|
150 |
+
)
|
151 |
+
dataset_query = OxfordParisDataset(
|
152 |
+
args.data_path,
|
153 |
+
args.dataset,
|
154 |
+
split="query",
|
155 |
+
transform=transform,
|
156 |
+
imsize=args.imsize,
|
157 |
+
)
|
158 |
+
sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
|
159 |
+
data_loader_train = torch.utils.data.DataLoader(
|
160 |
+
dataset_train,
|
161 |
+
sampler=sampler,
|
162 |
+
batch_size=1,
|
163 |
+
num_workers=args.num_workers,
|
164 |
+
pin_memory=True,
|
165 |
+
drop_last=False,
|
166 |
+
)
|
167 |
+
data_loader_query = torch.utils.data.DataLoader(
|
168 |
+
dataset_query,
|
169 |
+
batch_size=1,
|
170 |
+
num_workers=args.num_workers,
|
171 |
+
pin_memory=True,
|
172 |
+
drop_last=False,
|
173 |
+
)
|
174 |
+
print(f"train: {len(dataset_train)} imgs / query: {len(dataset_query)} imgs")
|
175 |
+
|
176 |
+
# ============ building network ... ============
|
177 |
+
if "vit" in args.arch:
|
178 |
+
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
|
179 |
+
print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
|
180 |
+
elif "xcit" in args.arch:
|
181 |
+
model = torch.hub.load("facebookresearch/xcit:main", args.arch, num_classes=0)
|
182 |
+
elif args.arch in torchvision_models.__dict__.keys():
|
183 |
+
model = torchvision_models.__dict__[args.arch](num_classes=0)
|
184 |
+
else:
|
185 |
+
print(f"Architecture {args.arch} non supported")
|
186 |
+
sys.exit(1)
|
187 |
+
if args.use_cuda:
|
188 |
+
model.cuda()
|
189 |
+
model.eval()
|
190 |
+
|
191 |
+
# load pretrained weights
|
192 |
+
if os.path.isfile(args.pretrained_weights):
|
193 |
+
state_dict = torch.load(args.pretrained_weights, map_location="cpu")
|
194 |
+
if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
|
195 |
+
print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
|
196 |
+
state_dict = state_dict[args.checkpoint_key]
|
197 |
+
# remove `module.` prefix
|
198 |
+
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
199 |
+
# remove `backbone.` prefix induced by multicrop wrapper
|
200 |
+
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
|
201 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
202 |
+
print(
|
203 |
+
"Pretrained weights found at {} and loaded with msg: {}".format(
|
204 |
+
args.pretrained_weights, msg
|
205 |
+
)
|
206 |
+
)
|
207 |
+
elif args.arch == "vit_small" and args.patch_size == 16:
|
208 |
+
print(
|
209 |
+
"Since no pretrained weights have been provided, we load pretrained DINO weights on Google Landmark v2."
|
210 |
+
)
|
211 |
+
model.load_state_dict(
|
212 |
+
torch.hub.load_state_dict_from_url(
|
213 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_vitsmall16_googlelandmark_pretrain/dino_vitsmall16_googlelandmark_pretrain.pth"
|
214 |
+
)
|
215 |
+
)
|
216 |
+
else:
|
217 |
+
print("Warning: We use random weights.")
|
218 |
+
|
219 |
+
############################################################################
|
220 |
+
# Step 1: extract features
|
221 |
+
train_features = extract_features(
|
222 |
+
model, data_loader_train, args.use_cuda, multiscale=args.multiscale
|
223 |
+
)
|
224 |
+
query_features = extract_features(
|
225 |
+
model, data_loader_query, args.use_cuda, multiscale=args.multiscale
|
226 |
+
)
|
227 |
+
|
228 |
+
if utils.get_rank() == 0: # only rank 0 will work from now on
|
229 |
+
# normalize features
|
230 |
+
train_features = nn.functional.normalize(train_features, dim=1, p=2)
|
231 |
+
query_features = nn.functional.normalize(query_features, dim=1, p=2)
|
232 |
+
|
233 |
+
############################################################################
|
234 |
+
# Step 2: similarity
|
235 |
+
sim = torch.mm(train_features, query_features.T)
|
236 |
+
ranks = torch.argsort(-sim, dim=0).cpu().numpy()
|
237 |
+
|
238 |
+
############################################################################
|
239 |
+
# Step 3: evaluate
|
240 |
+
gnd = dataset_train.cfg["gnd"]
|
241 |
+
# evaluate ranks
|
242 |
+
ks = [1, 5, 10]
|
243 |
+
# search for easy & hard
|
244 |
+
gnd_t = []
|
245 |
+
for i in range(len(gnd)):
|
246 |
+
g = {}
|
247 |
+
g["ok"] = np.concatenate([gnd[i]["easy"], gnd[i]["hard"]])
|
248 |
+
g["junk"] = np.concatenate([gnd[i]["junk"]])
|
249 |
+
gnd_t.append(g)
|
250 |
+
mapM, apsM, mprM, prsM = utils.compute_map(ranks, gnd_t, ks)
|
251 |
+
# search for hard
|
252 |
+
gnd_t = []
|
253 |
+
for i in range(len(gnd)):
|
254 |
+
g = {}
|
255 |
+
g["ok"] = np.concatenate([gnd[i]["hard"]])
|
256 |
+
g["junk"] = np.concatenate([gnd[i]["junk"], gnd[i]["easy"]])
|
257 |
+
gnd_t.append(g)
|
258 |
+
mapH, apsH, mprH, prsH = utils.compute_map(ranks, gnd_t, ks)
|
259 |
+
print(
|
260 |
+
">> {}: mAP M: {}, H: {}".format(
|
261 |
+
args.dataset,
|
262 |
+
np.around(mapM * 100, decimals=2),
|
263 |
+
np.around(mapH * 100, decimals=2),
|
264 |
+
)
|
265 |
+
)
|
266 |
+
print(
|
267 |
+
">> {}: mP@k{} M: {}, H: {}".format(
|
268 |
+
args.dataset,
|
269 |
+
np.array(ks),
|
270 |
+
np.around(mprM * 100, decimals=2),
|
271 |
+
np.around(mprH * 100, decimals=2),
|
272 |
+
)
|
273 |
+
)
|
274 |
+
dist.barrier()
|
dino/eval_knn.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
import argparse
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
import torch.distributed as dist
|
21 |
+
import torch.backends.cudnn as cudnn
|
22 |
+
from torchvision import datasets
|
23 |
+
from torchvision import transforms as pth_transforms
|
24 |
+
from torchvision import models as torchvision_models
|
25 |
+
|
26 |
+
import utils
|
27 |
+
import vision_transformer as vits
|
28 |
+
|
29 |
+
|
30 |
+
def extract_feature_pipeline(args):
|
31 |
+
# ============ preparing data ... ============
|
32 |
+
transform = pth_transforms.Compose(
|
33 |
+
[
|
34 |
+
pth_transforms.Resize(256, interpolation=3),
|
35 |
+
pth_transforms.CenterCrop(224),
|
36 |
+
pth_transforms.ToTensor(),
|
37 |
+
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
38 |
+
]
|
39 |
+
)
|
40 |
+
dataset_train = ReturnIndexDataset(
|
41 |
+
os.path.join(args.data_path, "train"), transform=transform
|
42 |
+
)
|
43 |
+
dataset_val = ReturnIndexDataset(
|
44 |
+
os.path.join(args.data_path, "val"), transform=transform
|
45 |
+
)
|
46 |
+
sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
|
47 |
+
data_loader_train = torch.utils.data.DataLoader(
|
48 |
+
dataset_train,
|
49 |
+
sampler=sampler,
|
50 |
+
batch_size=args.batch_size_per_gpu,
|
51 |
+
num_workers=args.num_workers,
|
52 |
+
pin_memory=True,
|
53 |
+
drop_last=False,
|
54 |
+
)
|
55 |
+
data_loader_val = torch.utils.data.DataLoader(
|
56 |
+
dataset_val,
|
57 |
+
batch_size=args.batch_size_per_gpu,
|
58 |
+
num_workers=args.num_workers,
|
59 |
+
pin_memory=True,
|
60 |
+
drop_last=False,
|
61 |
+
)
|
62 |
+
print(
|
63 |
+
f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs."
|
64 |
+
)
|
65 |
+
|
66 |
+
# ============ building network ... ============
|
67 |
+
if "vit" in args.arch:
|
68 |
+
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
|
69 |
+
print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
|
70 |
+
elif "xcit" in args.arch:
|
71 |
+
model = torch.hub.load("facebookresearch/xcit:main", args.arch, num_classes=0)
|
72 |
+
elif args.arch in torchvision_models.__dict__.keys():
|
73 |
+
model = torchvision_models.__dict__[args.arch](num_classes=0)
|
74 |
+
model.fc = nn.Identity()
|
75 |
+
else:
|
76 |
+
print(f"Architecture {args.arch} non supported")
|
77 |
+
sys.exit(1)
|
78 |
+
model.cuda()
|
79 |
+
utils.load_pretrained_weights(
|
80 |
+
model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size
|
81 |
+
)
|
82 |
+
model.eval()
|
83 |
+
|
84 |
+
# ============ extract features ... ============
|
85 |
+
print("Extracting features for train set...")
|
86 |
+
train_features = extract_features(model, data_loader_train, args.use_cuda)
|
87 |
+
print("Extracting features for val set...")
|
88 |
+
test_features = extract_features(model, data_loader_val, args.use_cuda)
|
89 |
+
|
90 |
+
if utils.get_rank() == 0:
|
91 |
+
train_features = nn.functional.normalize(train_features, dim=1, p=2)
|
92 |
+
test_features = nn.functional.normalize(test_features, dim=1, p=2)
|
93 |
+
|
94 |
+
train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long()
|
95 |
+
test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long()
|
96 |
+
# save features and labels
|
97 |
+
if args.dump_features and dist.get_rank() == 0:
|
98 |
+
torch.save(
|
99 |
+
train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth")
|
100 |
+
)
|
101 |
+
torch.save(
|
102 |
+
test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth")
|
103 |
+
)
|
104 |
+
torch.save(
|
105 |
+
train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth")
|
106 |
+
)
|
107 |
+
torch.save(
|
108 |
+
test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth")
|
109 |
+
)
|
110 |
+
return train_features, test_features, train_labels, test_labels
|
111 |
+
|
112 |
+
|
113 |
+
@torch.no_grad()
|
114 |
+
def extract_features(model, data_loader, use_cuda=True, multiscale=False):
|
115 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
116 |
+
features = None
|
117 |
+
for samples, index in metric_logger.log_every(data_loader, 10):
|
118 |
+
samples = samples.cuda(non_blocking=True)
|
119 |
+
index = index.cuda(non_blocking=True)
|
120 |
+
if multiscale:
|
121 |
+
feats = utils.multi_scale(samples, model)
|
122 |
+
else:
|
123 |
+
feats = model(samples).clone()
|
124 |
+
|
125 |
+
# init storage feature matrix
|
126 |
+
if dist.get_rank() == 0 and features is None:
|
127 |
+
features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
|
128 |
+
if use_cuda:
|
129 |
+
features = features.cuda(non_blocking=True)
|
130 |
+
print(f"Storing features into tensor of shape {features.shape}")
|
131 |
+
|
132 |
+
# get indexes from all processes
|
133 |
+
y_all = torch.empty(
|
134 |
+
dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device
|
135 |
+
)
|
136 |
+
y_l = list(y_all.unbind(0))
|
137 |
+
y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
|
138 |
+
y_all_reduce.wait()
|
139 |
+
index_all = torch.cat(y_l)
|
140 |
+
|
141 |
+
# share features between processes
|
142 |
+
feats_all = torch.empty(
|
143 |
+
dist.get_world_size(),
|
144 |
+
feats.size(0),
|
145 |
+
feats.size(1),
|
146 |
+
dtype=feats.dtype,
|
147 |
+
device=feats.device,
|
148 |
+
)
|
149 |
+
output_l = list(feats_all.unbind(0))
|
150 |
+
output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
|
151 |
+
output_all_reduce.wait()
|
152 |
+
|
153 |
+
# update storage feature matrix
|
154 |
+
if dist.get_rank() == 0:
|
155 |
+
if use_cuda:
|
156 |
+
features.index_copy_(0, index_all, torch.cat(output_l))
|
157 |
+
else:
|
158 |
+
features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
|
159 |
+
return features
|
160 |
+
|
161 |
+
|
162 |
+
@torch.no_grad()
|
163 |
+
def knn_classifier(
|
164 |
+
train_features, train_labels, test_features, test_labels, k, T, num_classes=1000
|
165 |
+
):
|
166 |
+
top1, top5, total = 0.0, 0.0, 0
|
167 |
+
train_features = train_features.t()
|
168 |
+
num_test_images, num_chunks = test_labels.shape[0], 100
|
169 |
+
imgs_per_chunk = num_test_images // num_chunks
|
170 |
+
retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)
|
171 |
+
for idx in range(0, num_test_images, imgs_per_chunk):
|
172 |
+
# get the features for test images
|
173 |
+
features = test_features[idx : min((idx + imgs_per_chunk), num_test_images), :]
|
174 |
+
targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)]
|
175 |
+
batch_size = targets.shape[0]
|
176 |
+
|
177 |
+
# calculate the dot product and compute top-k neighbors
|
178 |
+
similarity = torch.mm(features, train_features)
|
179 |
+
distances, indices = similarity.topk(k, largest=True, sorted=True)
|
180 |
+
candidates = train_labels.view(1, -1).expand(batch_size, -1)
|
181 |
+
retrieved_neighbors = torch.gather(candidates, 1, indices)
|
182 |
+
|
183 |
+
retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
|
184 |
+
retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
|
185 |
+
distances_transform = distances.clone().div_(T).exp_()
|
186 |
+
probs = torch.sum(
|
187 |
+
torch.mul(
|
188 |
+
retrieval_one_hot.view(batch_size, -1, num_classes),
|
189 |
+
distances_transform.view(batch_size, -1, 1),
|
190 |
+
),
|
191 |
+
1,
|
192 |
+
)
|
193 |
+
_, predictions = probs.sort(1, True)
|
194 |
+
|
195 |
+
# find the predictions that match the target
|
196 |
+
correct = predictions.eq(targets.data.view(-1, 1))
|
197 |
+
top1 = top1 + correct.narrow(1, 0, 1).sum().item()
|
198 |
+
top5 = (
|
199 |
+
top5 + correct.narrow(1, 0, min(5, k)).sum().item()
|
200 |
+
) # top5 does not make sense if k < 5
|
201 |
+
total += targets.size(0)
|
202 |
+
top1 = top1 * 100.0 / total
|
203 |
+
top5 = top5 * 100.0 / total
|
204 |
+
return top1, top5
|
205 |
+
|
206 |
+
|
207 |
+
class ReturnIndexDataset(datasets.ImageFolder):
|
208 |
+
def __getitem__(self, idx):
|
209 |
+
img, lab = super(ReturnIndexDataset, self).__getitem__(idx)
|
210 |
+
return img, idx
|
211 |
+
|
212 |
+
|
213 |
+
if __name__ == "__main__":
|
214 |
+
parser = argparse.ArgumentParser("Evaluation with weighted k-NN on ImageNet")
|
215 |
+
parser.add_argument(
|
216 |
+
"--batch_size_per_gpu", default=128, type=int, help="Per-GPU batch-size"
|
217 |
+
)
|
218 |
+
parser.add_argument(
|
219 |
+
"--nb_knn",
|
220 |
+
default=[10, 20, 100, 200],
|
221 |
+
nargs="+",
|
222 |
+
type=int,
|
223 |
+
help="Number of NN to use. 20 is usually working the best.",
|
224 |
+
)
|
225 |
+
parser.add_argument(
|
226 |
+
"--temperature",
|
227 |
+
default=0.07,
|
228 |
+
type=float,
|
229 |
+
help="Temperature used in the voting coefficient",
|
230 |
+
)
|
231 |
+
parser.add_argument(
|
232 |
+
"--pretrained_weights",
|
233 |
+
default="",
|
234 |
+
type=str,
|
235 |
+
help="Path to pretrained weights to evaluate.",
|
236 |
+
)
|
237 |
+
parser.add_argument(
|
238 |
+
"--use_cuda",
|
239 |
+
default=True,
|
240 |
+
type=utils.bool_flag,
|
241 |
+
help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM",
|
242 |
+
)
|
243 |
+
parser.add_argument("--arch", default="vit_small", type=str, help="Architecture")
|
244 |
+
parser.add_argument(
|
245 |
+
"--patch_size", default=16, type=int, help="Patch resolution of the model."
|
246 |
+
)
|
247 |
+
parser.add_argument(
|
248 |
+
"--checkpoint_key",
|
249 |
+
default="teacher",
|
250 |
+
type=str,
|
251 |
+
help='Key to use in the checkpoint (example: "teacher")',
|
252 |
+
)
|
253 |
+
parser.add_argument(
|
254 |
+
"--dump_features",
|
255 |
+
default=None,
|
256 |
+
help="Path where to save computed features, empty for no saving",
|
257 |
+
)
|
258 |
+
parser.add_argument(
|
259 |
+
"--load_features",
|
260 |
+
default=None,
|
261 |
+
help="""If the features have
|
262 |
+
already been computed, where to find them.""",
|
263 |
+
)
|
264 |
+
parser.add_argument(
|
265 |
+
"--num_workers",
|
266 |
+
default=10,
|
267 |
+
type=int,
|
268 |
+
help="Number of data loading workers per GPU.",
|
269 |
+
)
|
270 |
+
parser.add_argument(
|
271 |
+
"--dist_url",
|
272 |
+
default="env://",
|
273 |
+
type=str,
|
274 |
+
help="""url used to set up
|
275 |
+
distributed training; see https://pytorch.org/docs/stable/distributed.html""",
|
276 |
+
)
|
277 |
+
parser.add_argument(
|
278 |
+
"--local_rank",
|
279 |
+
default=0,
|
280 |
+
type=int,
|
281 |
+
help="Please ignore and do not set this argument.",
|
282 |
+
)
|
283 |
+
parser.add_argument("--data_path", default="/path/to/imagenet/", type=str)
|
284 |
+
args = parser.parse_args()
|
285 |
+
|
286 |
+
utils.init_distributed_mode(args)
|
287 |
+
print("git:\n {}\n".format(utils.get_sha()))
|
288 |
+
print(
|
289 |
+
"\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
|
290 |
+
)
|
291 |
+
cudnn.benchmark = True
|
292 |
+
|
293 |
+
if args.load_features:
|
294 |
+
train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth"))
|
295 |
+
test_features = torch.load(os.path.join(args.load_features, "testfeat.pth"))
|
296 |
+
train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth"))
|
297 |
+
test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth"))
|
298 |
+
else:
|
299 |
+
# need to extract features !
|
300 |
+
(
|
301 |
+
train_features,
|
302 |
+
test_features,
|
303 |
+
train_labels,
|
304 |
+
test_labels,
|
305 |
+
) = extract_feature_pipeline(args)
|
306 |
+
|
307 |
+
if utils.get_rank() == 0:
|
308 |
+
if args.use_cuda:
|
309 |
+
train_features = train_features.cuda()
|
310 |
+
test_features = test_features.cuda()
|
311 |
+
train_labels = train_labels.cuda()
|
312 |
+
test_labels = test_labels.cuda()
|
313 |
+
|
314 |
+
print("Features are ready!\nStart the k-NN classification.")
|
315 |
+
for k in args.nb_knn:
|
316 |
+
top1, top5 = knn_classifier(
|
317 |
+
train_features,
|
318 |
+
train_labels,
|
319 |
+
test_features,
|
320 |
+
test_labels,
|
321 |
+
k,
|
322 |
+
args.temperature,
|
323 |
+
)
|
324 |
+
print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}")
|
325 |
+
dist.barrier()
|
dino/eval_linear.py
ADDED
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import argparse
|
16 |
+
import json
|
17 |
+
from pathlib import Path
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch import nn
|
21 |
+
import torch.distributed as dist
|
22 |
+
import torch.backends.cudnn as cudnn
|
23 |
+
from torchvision import datasets
|
24 |
+
from torchvision import transforms as pth_transforms
|
25 |
+
from torchvision import models as torchvision_models
|
26 |
+
|
27 |
+
import utils
|
28 |
+
import vision_transformer as vits
|
29 |
+
|
30 |
+
|
31 |
+
def eval_linear(args):
|
32 |
+
utils.init_distributed_mode(args)
|
33 |
+
print("git:\n {}\n".format(utils.get_sha()))
|
34 |
+
print(
|
35 |
+
"\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
|
36 |
+
)
|
37 |
+
cudnn.benchmark = True
|
38 |
+
|
39 |
+
# ============ building network ... ============
|
40 |
+
# if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
|
41 |
+
if args.arch in vits.__dict__.keys():
|
42 |
+
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
|
43 |
+
embed_dim = model.embed_dim * (
|
44 |
+
args.n_last_blocks + int(args.avgpool_patchtokens)
|
45 |
+
)
|
46 |
+
# if the network is a XCiT
|
47 |
+
elif "xcit" in args.arch:
|
48 |
+
model = torch.hub.load("facebookresearch/xcit:main", args.arch, num_classes=0)
|
49 |
+
embed_dim = model.embed_dim
|
50 |
+
# otherwise, we check if the architecture is in torchvision models
|
51 |
+
elif args.arch in torchvision_models.__dict__.keys():
|
52 |
+
model = torchvision_models.__dict__[args.arch]()
|
53 |
+
embed_dim = model.fc.weight.shape[1]
|
54 |
+
model.fc = nn.Identity()
|
55 |
+
else:
|
56 |
+
print(f"Unknow architecture: {args.arch}")
|
57 |
+
sys.exit(1)
|
58 |
+
model.cuda()
|
59 |
+
model.eval()
|
60 |
+
# load weights to evaluate
|
61 |
+
utils.load_pretrained_weights(
|
62 |
+
model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size
|
63 |
+
)
|
64 |
+
print(f"Model {args.arch} built.")
|
65 |
+
|
66 |
+
linear_classifier = LinearClassifier(embed_dim, num_labels=args.num_labels)
|
67 |
+
linear_classifier = linear_classifier.cuda()
|
68 |
+
linear_classifier = nn.parallel.DistributedDataParallel(
|
69 |
+
linear_classifier, device_ids=[args.gpu]
|
70 |
+
)
|
71 |
+
|
72 |
+
# ============ preparing data ... ============
|
73 |
+
val_transform = pth_transforms.Compose(
|
74 |
+
[
|
75 |
+
pth_transforms.Resize(256, interpolation=3),
|
76 |
+
pth_transforms.CenterCrop(224),
|
77 |
+
pth_transforms.ToTensor(),
|
78 |
+
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
79 |
+
]
|
80 |
+
)
|
81 |
+
dataset_val = datasets.ImageFolder(
|
82 |
+
os.path.join(args.data_path, "val"), transform=val_transform
|
83 |
+
)
|
84 |
+
val_loader = torch.utils.data.DataLoader(
|
85 |
+
dataset_val,
|
86 |
+
batch_size=args.batch_size_per_gpu,
|
87 |
+
num_workers=args.num_workers,
|
88 |
+
pin_memory=True,
|
89 |
+
)
|
90 |
+
|
91 |
+
if args.evaluate:
|
92 |
+
utils.load_pretrained_linear_weights(
|
93 |
+
linear_classifier, args.arch, args.patch_size
|
94 |
+
)
|
95 |
+
test_stats = validate_network(
|
96 |
+
val_loader,
|
97 |
+
model,
|
98 |
+
linear_classifier,
|
99 |
+
args.n_last_blocks,
|
100 |
+
args.avgpool_patchtokens,
|
101 |
+
)
|
102 |
+
print(
|
103 |
+
f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
|
104 |
+
)
|
105 |
+
return
|
106 |
+
|
107 |
+
train_transform = pth_transforms.Compose(
|
108 |
+
[
|
109 |
+
pth_transforms.RandomResizedCrop(224),
|
110 |
+
pth_transforms.RandomHorizontalFlip(),
|
111 |
+
pth_transforms.ToTensor(),
|
112 |
+
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
113 |
+
]
|
114 |
+
)
|
115 |
+
dataset_train = datasets.ImageFolder(
|
116 |
+
os.path.join(args.data_path, "train"), transform=train_transform
|
117 |
+
)
|
118 |
+
sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
|
119 |
+
train_loader = torch.utils.data.DataLoader(
|
120 |
+
dataset_train,
|
121 |
+
sampler=sampler,
|
122 |
+
batch_size=args.batch_size_per_gpu,
|
123 |
+
num_workers=args.num_workers,
|
124 |
+
pin_memory=True,
|
125 |
+
)
|
126 |
+
print(
|
127 |
+
f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs."
|
128 |
+
)
|
129 |
+
|
130 |
+
# set optimizer
|
131 |
+
optimizer = torch.optim.SGD(
|
132 |
+
linear_classifier.parameters(),
|
133 |
+
args.lr
|
134 |
+
* (args.batch_size_per_gpu * utils.get_world_size())
|
135 |
+
/ 256.0, # linear scaling rule
|
136 |
+
momentum=0.9,
|
137 |
+
weight_decay=0, # we do not apply weight decay
|
138 |
+
)
|
139 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
140 |
+
optimizer, args.epochs, eta_min=0
|
141 |
+
)
|
142 |
+
|
143 |
+
# Optionally resume from a checkpoint
|
144 |
+
to_restore = {"epoch": 0, "best_acc": 0.0}
|
145 |
+
utils.restart_from_checkpoint(
|
146 |
+
os.path.join(args.output_dir, "checkpoint.pth.tar"),
|
147 |
+
run_variables=to_restore,
|
148 |
+
state_dict=linear_classifier,
|
149 |
+
optimizer=optimizer,
|
150 |
+
scheduler=scheduler,
|
151 |
+
)
|
152 |
+
start_epoch = to_restore["epoch"]
|
153 |
+
best_acc = to_restore["best_acc"]
|
154 |
+
|
155 |
+
for epoch in range(start_epoch, args.epochs):
|
156 |
+
train_loader.sampler.set_epoch(epoch)
|
157 |
+
|
158 |
+
train_stats = train(
|
159 |
+
model,
|
160 |
+
linear_classifier,
|
161 |
+
optimizer,
|
162 |
+
train_loader,
|
163 |
+
epoch,
|
164 |
+
args.n_last_blocks,
|
165 |
+
args.avgpool_patchtokens,
|
166 |
+
)
|
167 |
+
scheduler.step()
|
168 |
+
|
169 |
+
log_stats = {
|
170 |
+
**{f"train_{k}": v for k, v in train_stats.items()},
|
171 |
+
"epoch": epoch,
|
172 |
+
}
|
173 |
+
if epoch % args.val_freq == 0 or epoch == args.epochs - 1:
|
174 |
+
test_stats = validate_network(
|
175 |
+
val_loader,
|
176 |
+
model,
|
177 |
+
linear_classifier,
|
178 |
+
args.n_last_blocks,
|
179 |
+
args.avgpool_patchtokens,
|
180 |
+
)
|
181 |
+
print(
|
182 |
+
f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%"
|
183 |
+
)
|
184 |
+
best_acc = max(best_acc, test_stats["acc1"])
|
185 |
+
print(f"Max accuracy so far: {best_acc:.2f}%")
|
186 |
+
log_stats = {
|
187 |
+
**{k: v for k, v in log_stats.items()},
|
188 |
+
**{f"test_{k}": v for k, v in test_stats.items()},
|
189 |
+
}
|
190 |
+
if utils.is_main_process():
|
191 |
+
with (Path(args.output_dir) / "log.txt").open("a") as f:
|
192 |
+
f.write(json.dumps(log_stats) + "\n")
|
193 |
+
save_dict = {
|
194 |
+
"epoch": epoch + 1,
|
195 |
+
"state_dict": linear_classifier.state_dict(),
|
196 |
+
"optimizer": optimizer.state_dict(),
|
197 |
+
"scheduler": scheduler.state_dict(),
|
198 |
+
"best_acc": best_acc,
|
199 |
+
}
|
200 |
+
torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar"))
|
201 |
+
print(
|
202 |
+
"Training of the supervised linear classifier on frozen features completed.\n"
|
203 |
+
"Top-1 test accuracy: {acc:.1f}".format(acc=best_acc)
|
204 |
+
)
|
205 |
+
|
206 |
+
|
207 |
+
def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool):
|
208 |
+
linear_classifier.train()
|
209 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
210 |
+
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
|
211 |
+
header = "Epoch: [{}]".format(epoch)
|
212 |
+
for (inp, target) in metric_logger.log_every(loader, 20, header):
|
213 |
+
# move to gpu
|
214 |
+
inp = inp.cuda(non_blocking=True)
|
215 |
+
target = target.cuda(non_blocking=True)
|
216 |
+
|
217 |
+
# forward
|
218 |
+
with torch.no_grad():
|
219 |
+
if "vit" in args.arch:
|
220 |
+
intermediate_output = model.get_intermediate_layers(inp, n)
|
221 |
+
output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
|
222 |
+
if avgpool:
|
223 |
+
output = torch.cat(
|
224 |
+
(
|
225 |
+
output.unsqueeze(-1),
|
226 |
+
torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(
|
227 |
+
-1
|
228 |
+
),
|
229 |
+
),
|
230 |
+
dim=-1,
|
231 |
+
)
|
232 |
+
output = output.reshape(output.shape[0], -1)
|
233 |
+
else:
|
234 |
+
output = model(inp)
|
235 |
+
output = linear_classifier(output)
|
236 |
+
|
237 |
+
# compute cross entropy loss
|
238 |
+
loss = nn.CrossEntropyLoss()(output, target)
|
239 |
+
|
240 |
+
# compute the gradients
|
241 |
+
optimizer.zero_grad()
|
242 |
+
loss.backward()
|
243 |
+
|
244 |
+
# step
|
245 |
+
optimizer.step()
|
246 |
+
|
247 |
+
# log
|
248 |
+
torch.cuda.synchronize()
|
249 |
+
metric_logger.update(loss=loss.item())
|
250 |
+
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
251 |
+
# gather the stats from all processes
|
252 |
+
metric_logger.synchronize_between_processes()
|
253 |
+
print("Averaged stats:", metric_logger)
|
254 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
255 |
+
|
256 |
+
|
257 |
+
@torch.no_grad()
|
258 |
+
def validate_network(val_loader, model, linear_classifier, n, avgpool):
|
259 |
+
linear_classifier.eval()
|
260 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
261 |
+
header = "Test:"
|
262 |
+
for inp, target in metric_logger.log_every(val_loader, 20, header):
|
263 |
+
# move to gpu
|
264 |
+
inp = inp.cuda(non_blocking=True)
|
265 |
+
target = target.cuda(non_blocking=True)
|
266 |
+
|
267 |
+
# forward
|
268 |
+
with torch.no_grad():
|
269 |
+
if "vit" in args.arch:
|
270 |
+
intermediate_output = model.get_intermediate_layers(inp, n)
|
271 |
+
output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
|
272 |
+
if avgpool:
|
273 |
+
output = torch.cat(
|
274 |
+
(
|
275 |
+
output.unsqueeze(-1),
|
276 |
+
torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(
|
277 |
+
-1
|
278 |
+
),
|
279 |
+
),
|
280 |
+
dim=-1,
|
281 |
+
)
|
282 |
+
output = output.reshape(output.shape[0], -1)
|
283 |
+
else:
|
284 |
+
output = model(inp)
|
285 |
+
output = linear_classifier(output)
|
286 |
+
loss = nn.CrossEntropyLoss()(output, target)
|
287 |
+
|
288 |
+
if linear_classifier.module.num_labels >= 5:
|
289 |
+
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
|
290 |
+
else:
|
291 |
+
(acc1,) = utils.accuracy(output, target, topk=(1,))
|
292 |
+
|
293 |
+
batch_size = inp.shape[0]
|
294 |
+
metric_logger.update(loss=loss.item())
|
295 |
+
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
|
296 |
+
if linear_classifier.module.num_labels >= 5:
|
297 |
+
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
|
298 |
+
if linear_classifier.module.num_labels >= 5:
|
299 |
+
print(
|
300 |
+
"* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}".format(
|
301 |
+
top1=metric_logger.acc1,
|
302 |
+
top5=metric_logger.acc5,
|
303 |
+
losses=metric_logger.loss,
|
304 |
+
)
|
305 |
+
)
|
306 |
+
else:
|
307 |
+
print(
|
308 |
+
"* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}".format(
|
309 |
+
top1=metric_logger.acc1, losses=metric_logger.loss
|
310 |
+
)
|
311 |
+
)
|
312 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
313 |
+
|
314 |
+
|
315 |
+
class LinearClassifier(nn.Module):
|
316 |
+
"""Linear layer to train on top of frozen features"""
|
317 |
+
|
318 |
+
def __init__(self, dim, num_labels=1000):
|
319 |
+
super(LinearClassifier, self).__init__()
|
320 |
+
self.num_labels = num_labels
|
321 |
+
self.linear = nn.Linear(dim, num_labels)
|
322 |
+
self.linear.weight.data.normal_(mean=0.0, std=0.01)
|
323 |
+
self.linear.bias.data.zero_()
|
324 |
+
|
325 |
+
def forward(self, x):
|
326 |
+
# flatten
|
327 |
+
x = x.view(x.size(0), -1)
|
328 |
+
|
329 |
+
# linear layer
|
330 |
+
return self.linear(x)
|
331 |
+
|
332 |
+
|
333 |
+
if __name__ == "__main__":
|
334 |
+
parser = argparse.ArgumentParser(
|
335 |
+
"Evaluation with linear classification on ImageNet"
|
336 |
+
)
|
337 |
+
parser.add_argument(
|
338 |
+
"--n_last_blocks",
|
339 |
+
default=4,
|
340 |
+
type=int,
|
341 |
+
help="""Concatenate [CLS] tokens
|
342 |
+
for the `n` last blocks. We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""",
|
343 |
+
)
|
344 |
+
parser.add_argument(
|
345 |
+
"--avgpool_patchtokens",
|
346 |
+
default=False,
|
347 |
+
type=utils.bool_flag,
|
348 |
+
help="""Whether ot not to concatenate the global average pooled features to the [CLS] token.
|
349 |
+
We typically set this to False for ViT-Small and to True with ViT-Base.""",
|
350 |
+
)
|
351 |
+
parser.add_argument("--arch", default="vit_small", type=str, help="Architecture")
|
352 |
+
parser.add_argument(
|
353 |
+
"--patch_size", default=16, type=int, help="Patch resolution of the model."
|
354 |
+
)
|
355 |
+
parser.add_argument(
|
356 |
+
"--pretrained_weights",
|
357 |
+
default="",
|
358 |
+
type=str,
|
359 |
+
help="Path to pretrained weights to evaluate.",
|
360 |
+
)
|
361 |
+
parser.add_argument(
|
362 |
+
"--checkpoint_key",
|
363 |
+
default="teacher",
|
364 |
+
type=str,
|
365 |
+
help='Key to use in the checkpoint (example: "teacher")',
|
366 |
+
)
|
367 |
+
parser.add_argument(
|
368 |
+
"--epochs", default=100, type=int, help="Number of epochs of training."
|
369 |
+
)
|
370 |
+
parser.add_argument(
|
371 |
+
"--lr",
|
372 |
+
default=0.001,
|
373 |
+
type=float,
|
374 |
+
help="""Learning rate at the beginning of
|
375 |
+
training (highest LR used during training). The learning rate is linearly scaled
|
376 |
+
with the batch size, and specified here for a reference batch size of 256.
|
377 |
+
We recommend tweaking the LR depending on the checkpoint evaluated.""",
|
378 |
+
)
|
379 |
+
parser.add_argument(
|
380 |
+
"--batch_size_per_gpu", default=128, type=int, help="Per-GPU batch-size"
|
381 |
+
)
|
382 |
+
parser.add_argument(
|
383 |
+
"--dist_url",
|
384 |
+
default="env://",
|
385 |
+
type=str,
|
386 |
+
help="""url used to set up
|
387 |
+
distributed training; see https://pytorch.org/docs/stable/distributed.html""",
|
388 |
+
)
|
389 |
+
parser.add_argument(
|
390 |
+
"--local_rank",
|
391 |
+
default=0,
|
392 |
+
type=int,
|
393 |
+
help="Please ignore and do not set this argument.",
|
394 |
+
)
|
395 |
+
parser.add_argument("--data_path", default="/path/to/imagenet/", type=str)
|
396 |
+
parser.add_argument(
|
397 |
+
"--num_workers",
|
398 |
+
default=10,
|
399 |
+
type=int,
|
400 |
+
help="Number of data loading workers per GPU.",
|
401 |
+
)
|
402 |
+
parser.add_argument(
|
403 |
+
"--val_freq", default=1, type=int, help="Epoch frequency for validation."
|
404 |
+
)
|
405 |
+
parser.add_argument(
|
406 |
+
"--output_dir", default=".", help="Path to save logs and checkpoints"
|
407 |
+
)
|
408 |
+
parser.add_argument(
|
409 |
+
"--num_labels",
|
410 |
+
default=1000,
|
411 |
+
type=int,
|
412 |
+
help="Number of labels for linear classifier",
|
413 |
+
)
|
414 |
+
parser.add_argument(
|
415 |
+
"--evaluate",
|
416 |
+
dest="evaluate",
|
417 |
+
action="store_true",
|
418 |
+
help="evaluate model on validation set",
|
419 |
+
)
|
420 |
+
args = parser.parse_args()
|
421 |
+
eval_linear(args)
|
dino/eval_video_segmentation.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Some parts are taken from https://github.com/Liusifei/UVC
|
16 |
+
"""
|
17 |
+
import os
|
18 |
+
import copy
|
19 |
+
import glob
|
20 |
+
import queue
|
21 |
+
from urllib.request import urlopen
|
22 |
+
import argparse
|
23 |
+
import numpy as np
|
24 |
+
from tqdm import tqdm
|
25 |
+
|
26 |
+
import cv2
|
27 |
+
import torch
|
28 |
+
import torch.nn as nn
|
29 |
+
from torch.nn import functional as F
|
30 |
+
from PIL import Image
|
31 |
+
from torchvision import transforms
|
32 |
+
|
33 |
+
import utils
|
34 |
+
import vision_transformer as vits
|
35 |
+
|
36 |
+
|
37 |
+
@torch.no_grad()
|
38 |
+
def eval_video_tracking_davis(
|
39 |
+
args, model, frame_list, video_dir, first_seg, seg_ori, color_palette
|
40 |
+
):
|
41 |
+
"""
|
42 |
+
Evaluate tracking on a video given first frame & segmentation
|
43 |
+
"""
|
44 |
+
video_folder = os.path.join(args.output_dir, video_dir.split("/")[-1])
|
45 |
+
os.makedirs(video_folder, exist_ok=True)
|
46 |
+
|
47 |
+
# The queue stores the n preceeding frames
|
48 |
+
que = queue.Queue(args.n_last_frames)
|
49 |
+
|
50 |
+
# first frame
|
51 |
+
frame1, ori_h, ori_w = read_frame(frame_list[0])
|
52 |
+
# extract first frame feature
|
53 |
+
frame1_feat = extract_feature(model, frame1).T # dim x h*w
|
54 |
+
|
55 |
+
# saving first segmentation
|
56 |
+
out_path = os.path.join(video_folder, "00000.png")
|
57 |
+
imwrite_indexed(out_path, seg_ori, color_palette)
|
58 |
+
mask_neighborhood = None
|
59 |
+
for cnt in tqdm(range(1, len(frame_list))):
|
60 |
+
frame_tar = read_frame(frame_list[cnt])[0]
|
61 |
+
|
62 |
+
# we use the first segmentation and the n previous ones
|
63 |
+
used_frame_feats = [frame1_feat] + [pair[0] for pair in list(que.queue)]
|
64 |
+
used_segs = [first_seg] + [pair[1] for pair in list(que.queue)]
|
65 |
+
|
66 |
+
frame_tar_avg, feat_tar, mask_neighborhood = label_propagation(
|
67 |
+
args, model, frame_tar, used_frame_feats, used_segs, mask_neighborhood
|
68 |
+
)
|
69 |
+
|
70 |
+
# pop out oldest frame if neccessary
|
71 |
+
if que.qsize() == args.n_last_frames:
|
72 |
+
que.get()
|
73 |
+
# push current results into queue
|
74 |
+
seg = copy.deepcopy(frame_tar_avg)
|
75 |
+
que.put([feat_tar, seg])
|
76 |
+
|
77 |
+
# upsampling & argmax
|
78 |
+
frame_tar_avg = F.interpolate(
|
79 |
+
frame_tar_avg,
|
80 |
+
scale_factor=args.patch_size,
|
81 |
+
mode="bilinear",
|
82 |
+
align_corners=False,
|
83 |
+
recompute_scale_factor=False,
|
84 |
+
)[0]
|
85 |
+
frame_tar_avg = norm_mask(frame_tar_avg)
|
86 |
+
_, frame_tar_seg = torch.max(frame_tar_avg, dim=0)
|
87 |
+
|
88 |
+
# saving to disk
|
89 |
+
frame_tar_seg = np.array(frame_tar_seg.squeeze().cpu(), dtype=np.uint8)
|
90 |
+
frame_tar_seg = np.array(
|
91 |
+
Image.fromarray(frame_tar_seg).resize((ori_w, ori_h), 0)
|
92 |
+
)
|
93 |
+
frame_nm = frame_list[cnt].split("/")[-1].replace(".jpg", ".png")
|
94 |
+
imwrite_indexed(
|
95 |
+
os.path.join(video_folder, frame_nm), frame_tar_seg, color_palette
|
96 |
+
)
|
97 |
+
|
98 |
+
|
99 |
+
def restrict_neighborhood(h, w):
|
100 |
+
# We restrict the set of source nodes considered to a spatial neighborhood of the query node (i.e. ``local attention'')
|
101 |
+
mask = torch.zeros(h, w, h, w)
|
102 |
+
for i in range(h):
|
103 |
+
for j in range(w):
|
104 |
+
for p in range(2 * args.size_mask_neighborhood + 1):
|
105 |
+
for q in range(2 * args.size_mask_neighborhood + 1):
|
106 |
+
if (
|
107 |
+
i - args.size_mask_neighborhood + p < 0
|
108 |
+
or i - args.size_mask_neighborhood + p >= h
|
109 |
+
):
|
110 |
+
continue
|
111 |
+
if (
|
112 |
+
j - args.size_mask_neighborhood + q < 0
|
113 |
+
or j - args.size_mask_neighborhood + q >= w
|
114 |
+
):
|
115 |
+
continue
|
116 |
+
mask[
|
117 |
+
i,
|
118 |
+
j,
|
119 |
+
i - args.size_mask_neighborhood + p,
|
120 |
+
j - args.size_mask_neighborhood + q,
|
121 |
+
] = 1
|
122 |
+
|
123 |
+
mask = mask.reshape(h * w, h * w)
|
124 |
+
return mask.cuda(non_blocking=True)
|
125 |
+
|
126 |
+
|
127 |
+
def norm_mask(mask):
|
128 |
+
c, h, w = mask.size()
|
129 |
+
for cnt in range(c):
|
130 |
+
mask_cnt = mask[cnt, :, :]
|
131 |
+
if mask_cnt.max() > 0:
|
132 |
+
mask_cnt = mask_cnt - mask_cnt.min()
|
133 |
+
mask_cnt = mask_cnt / mask_cnt.max()
|
134 |
+
mask[cnt, :, :] = mask_cnt
|
135 |
+
return mask
|
136 |
+
|
137 |
+
|
138 |
+
def label_propagation(
|
139 |
+
args, model, frame_tar, list_frame_feats, list_segs, mask_neighborhood=None
|
140 |
+
):
|
141 |
+
"""
|
142 |
+
propagate segs of frames in list_frames to frame_tar
|
143 |
+
"""
|
144 |
+
## we only need to extract feature of the target frame
|
145 |
+
feat_tar, h, w = extract_feature(model, frame_tar, return_h_w=True)
|
146 |
+
|
147 |
+
return_feat_tar = feat_tar.T # dim x h*w
|
148 |
+
|
149 |
+
ncontext = len(list_frame_feats)
|
150 |
+
feat_sources = torch.stack(list_frame_feats) # nmb_context x dim x h*w
|
151 |
+
|
152 |
+
feat_tar = F.normalize(feat_tar, dim=1, p=2)
|
153 |
+
feat_sources = F.normalize(feat_sources, dim=1, p=2)
|
154 |
+
|
155 |
+
feat_tar = feat_tar.unsqueeze(0).repeat(ncontext, 1, 1)
|
156 |
+
aff = torch.exp(
|
157 |
+
torch.bmm(feat_tar, feat_sources) / 0.1
|
158 |
+
) # nmb_context x h*w (tar: query) x h*w (source: keys)
|
159 |
+
|
160 |
+
if args.size_mask_neighborhood > 0:
|
161 |
+
if mask_neighborhood is None:
|
162 |
+
mask_neighborhood = restrict_neighborhood(h, w)
|
163 |
+
mask_neighborhood = mask_neighborhood.unsqueeze(0).repeat(ncontext, 1, 1)
|
164 |
+
aff *= mask_neighborhood
|
165 |
+
|
166 |
+
aff = aff.transpose(2, 1).reshape(
|
167 |
+
-1, h * w
|
168 |
+
) # nmb_context*h*w (source: keys) x h*w (tar: queries)
|
169 |
+
tk_val, _ = torch.topk(aff, dim=0, k=args.topk)
|
170 |
+
tk_val_min, _ = torch.min(tk_val, dim=0)
|
171 |
+
aff[aff < tk_val_min] = 0
|
172 |
+
|
173 |
+
aff = aff / torch.sum(aff, keepdim=True, axis=0)
|
174 |
+
|
175 |
+
list_segs = [s.cuda() for s in list_segs]
|
176 |
+
segs = torch.cat(list_segs)
|
177 |
+
nmb_context, C, h, w = segs.shape
|
178 |
+
segs = (
|
179 |
+
segs.reshape(nmb_context, C, -1).transpose(2, 1).reshape(-1, C).T
|
180 |
+
) # C x nmb_context*h*w
|
181 |
+
seg_tar = torch.mm(segs, aff)
|
182 |
+
seg_tar = seg_tar.reshape(1, C, h, w)
|
183 |
+
return seg_tar, return_feat_tar, mask_neighborhood
|
184 |
+
|
185 |
+
|
186 |
+
def extract_feature(model, frame, return_h_w=False):
|
187 |
+
"""Extract one frame feature everytime."""
|
188 |
+
out = model.get_intermediate_layers(frame.unsqueeze(0).cuda(), n=1)[0]
|
189 |
+
out = out[:, 1:, :] # we discard the [CLS] token
|
190 |
+
h, w = int(frame.shape[1] / model.patch_embed.patch_size), int(
|
191 |
+
frame.shape[2] / model.patch_embed.patch_size
|
192 |
+
)
|
193 |
+
dim = out.shape[-1]
|
194 |
+
out = out[0].reshape(h, w, dim)
|
195 |
+
out = out.reshape(-1, dim)
|
196 |
+
if return_h_w:
|
197 |
+
return out, h, w
|
198 |
+
return out
|
199 |
+
|
200 |
+
|
201 |
+
def imwrite_indexed(filename, array, color_palette):
|
202 |
+
"""Save indexed png for DAVIS."""
|
203 |
+
if np.atleast_3d(array).shape[2] != 1:
|
204 |
+
raise Exception("Saving indexed PNGs requires 2D array.")
|
205 |
+
|
206 |
+
im = Image.fromarray(array)
|
207 |
+
im.putpalette(color_palette.ravel())
|
208 |
+
im.save(filename, format="PNG")
|
209 |
+
|
210 |
+
|
211 |
+
def to_one_hot(y_tensor, n_dims=None):
|
212 |
+
"""
|
213 |
+
Take integer y (tensor or variable) with n dims &
|
214 |
+
convert it to 1-hot representation with n+1 dims.
|
215 |
+
"""
|
216 |
+
if n_dims is None:
|
217 |
+
n_dims = int(y_tensor.max() + 1)
|
218 |
+
_, h, w = y_tensor.size()
|
219 |
+
y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1)
|
220 |
+
n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1
|
221 |
+
y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)
|
222 |
+
y_one_hot = y_one_hot.view(h, w, n_dims)
|
223 |
+
return y_one_hot.permute(2, 0, 1).unsqueeze(0)
|
224 |
+
|
225 |
+
|
226 |
+
def read_frame_list(video_dir):
|
227 |
+
frame_list = [img for img in glob.glob(os.path.join(video_dir, "*.jpg"))]
|
228 |
+
frame_list = sorted(frame_list)
|
229 |
+
return frame_list
|
230 |
+
|
231 |
+
|
232 |
+
def read_frame(frame_dir, scale_size=[480]):
|
233 |
+
"""
|
234 |
+
read a single frame & preprocess
|
235 |
+
"""
|
236 |
+
img = cv2.imread(frame_dir)
|
237 |
+
ori_h, ori_w, _ = img.shape
|
238 |
+
if len(scale_size) == 1:
|
239 |
+
if ori_h > ori_w:
|
240 |
+
tw = scale_size[0]
|
241 |
+
th = (tw * ori_h) / ori_w
|
242 |
+
th = int((th // 64) * 64)
|
243 |
+
else:
|
244 |
+
th = scale_size[0]
|
245 |
+
tw = (th * ori_w) / ori_h
|
246 |
+
tw = int((tw // 64) * 64)
|
247 |
+
else:
|
248 |
+
th, tw = scale_size
|
249 |
+
img = cv2.resize(img, (tw, th))
|
250 |
+
img = img.astype(np.float32)
|
251 |
+
img = img / 255.0
|
252 |
+
img = img[:, :, ::-1]
|
253 |
+
img = np.transpose(img.copy(), (2, 0, 1))
|
254 |
+
img = torch.from_numpy(img).float()
|
255 |
+
img = color_normalize(img)
|
256 |
+
return img, ori_h, ori_w
|
257 |
+
|
258 |
+
|
259 |
+
def read_seg(seg_dir, factor, scale_size=[480]):
|
260 |
+
seg = Image.open(seg_dir)
|
261 |
+
_w, _h = seg.size # note PIL.Image.Image's size is (w, h)
|
262 |
+
if len(scale_size) == 1:
|
263 |
+
if _w > _h:
|
264 |
+
_th = scale_size[0]
|
265 |
+
_tw = (_th * _w) / _h
|
266 |
+
_tw = int((_tw // 64) * 64)
|
267 |
+
else:
|
268 |
+
_tw = scale_size[0]
|
269 |
+
_th = (_tw * _h) / _w
|
270 |
+
_th = int((_th // 64) * 64)
|
271 |
+
else:
|
272 |
+
_th = scale_size[1]
|
273 |
+
_tw = scale_size[0]
|
274 |
+
small_seg = np.array(seg.resize((_tw // factor, _th // factor), 0))
|
275 |
+
small_seg = torch.from_numpy(small_seg.copy()).contiguous().float().unsqueeze(0)
|
276 |
+
return to_one_hot(small_seg), np.asarray(seg)
|
277 |
+
|
278 |
+
|
279 |
+
def color_normalize(x, mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]):
|
280 |
+
for t, m, s in zip(x, mean, std):
|
281 |
+
t.sub_(m)
|
282 |
+
t.div_(s)
|
283 |
+
return x
|
284 |
+
|
285 |
+
|
286 |
+
if __name__ == "__main__":
|
287 |
+
parser = argparse.ArgumentParser(
|
288 |
+
"Evaluation with video object segmentation on DAVIS 2017"
|
289 |
+
)
|
290 |
+
parser.add_argument(
|
291 |
+
"--pretrained_weights",
|
292 |
+
default="",
|
293 |
+
type=str,
|
294 |
+
help="Path to pretrained weights to evaluate.",
|
295 |
+
)
|
296 |
+
parser.add_argument(
|
297 |
+
"--arch",
|
298 |
+
default="vit_small",
|
299 |
+
type=str,
|
300 |
+
choices=["vit_tiny", "vit_small", "vit_base"],
|
301 |
+
help="Architecture (support only ViT atm).",
|
302 |
+
)
|
303 |
+
parser.add_argument(
|
304 |
+
"--patch_size", default=16, type=int, help="Patch resolution of the model."
|
305 |
+
)
|
306 |
+
parser.add_argument(
|
307 |
+
"--checkpoint_key",
|
308 |
+
default="teacher",
|
309 |
+
type=str,
|
310 |
+
help='Key to use in the checkpoint (example: "teacher")',
|
311 |
+
)
|
312 |
+
parser.add_argument(
|
313 |
+
"--output_dir", default=".", help="Path where to save segmentations"
|
314 |
+
)
|
315 |
+
parser.add_argument("--data_path", default="/path/to/davis/", type=str)
|
316 |
+
parser.add_argument(
|
317 |
+
"--n_last_frames", type=int, default=7, help="number of preceeding frames"
|
318 |
+
)
|
319 |
+
parser.add_argument(
|
320 |
+
"--size_mask_neighborhood",
|
321 |
+
default=12,
|
322 |
+
type=int,
|
323 |
+
help="We restrict the set of source nodes considered to a spatial neighborhood of the query node",
|
324 |
+
)
|
325 |
+
parser.add_argument(
|
326 |
+
"--topk", type=int, default=5, help="accumulate label from top k neighbors"
|
327 |
+
)
|
328 |
+
parser.add_argument(
|
329 |
+
"--bs", type=int, default=6, help="Batch size, try to reduce if OOM"
|
330 |
+
)
|
331 |
+
args = parser.parse_args()
|
332 |
+
|
333 |
+
print("git:\n {}\n".format(utils.get_sha()))
|
334 |
+
print(
|
335 |
+
"\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
|
336 |
+
)
|
337 |
+
|
338 |
+
# building network
|
339 |
+
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
|
340 |
+
print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
|
341 |
+
model.cuda()
|
342 |
+
utils.load_pretrained_weights(
|
343 |
+
model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size
|
344 |
+
)
|
345 |
+
for param in model.parameters():
|
346 |
+
param.requires_grad = False
|
347 |
+
model.eval()
|
348 |
+
|
349 |
+
color_palette = []
|
350 |
+
for line in urlopen(
|
351 |
+
"https://raw.githubusercontent.com/Liusifei/UVC/master/libs/data/palette.txt"
|
352 |
+
):
|
353 |
+
color_palette.append(
|
354 |
+
[int(i) for i in line.decode("utf-8").split("\n")[0].split(" ")]
|
355 |
+
)
|
356 |
+
color_palette = np.asarray(color_palette, dtype=np.uint8).reshape(-1, 3)
|
357 |
+
|
358 |
+
video_list = open(
|
359 |
+
os.path.join(args.data_path, "ImageSets/2017/val.txt")
|
360 |
+
).readlines()
|
361 |
+
for i, video_name in enumerate(video_list):
|
362 |
+
video_name = video_name.strip()
|
363 |
+
print(f"[{i}/{len(video_list)}] Begin to segmentate video {video_name}.")
|
364 |
+
video_dir = os.path.join(args.data_path, "JPEGImages/480p/", video_name)
|
365 |
+
frame_list = read_frame_list(video_dir)
|
366 |
+
seg_path = (
|
367 |
+
frame_list[0].replace("JPEGImages", "Annotations").replace("jpg", "png")
|
368 |
+
)
|
369 |
+
first_seg, seg_ori = read_seg(seg_path, args.patch_size)
|
370 |
+
eval_video_tracking_davis(
|
371 |
+
args, model, frame_list, video_dir, first_seg, seg_ori, color_palette
|
372 |
+
)
|
dino/hubconf.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
from torchvision.models.resnet import resnet50
|
16 |
+
|
17 |
+
import vision_transformer as vits
|
18 |
+
|
19 |
+
dependencies = ["torch", "torchvision"]
|
20 |
+
|
21 |
+
|
22 |
+
def dino_vits16(pretrained=True, **kwargs):
|
23 |
+
"""
|
24 |
+
ViT-Small/16x16 pre-trained with DINO.
|
25 |
+
Achieves 74.5% top-1 accuracy on ImageNet with k-NN classification.
|
26 |
+
"""
|
27 |
+
model = vits.__dict__["vit_small"](patch_size=16, num_classes=0, **kwargs)
|
28 |
+
if pretrained:
|
29 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
30 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
|
31 |
+
map_location="cpu",
|
32 |
+
)
|
33 |
+
model.load_state_dict(state_dict, strict=True)
|
34 |
+
return model
|
35 |
+
|
36 |
+
|
37 |
+
def dino_vits8(pretrained=True, **kwargs):
|
38 |
+
"""
|
39 |
+
ViT-Small/8x8 pre-trained with DINO.
|
40 |
+
Achieves 78.3% top-1 accuracy on ImageNet with k-NN classification.
|
41 |
+
"""
|
42 |
+
model = vits.__dict__["vit_small"](patch_size=8, num_classes=0, **kwargs)
|
43 |
+
if pretrained:
|
44 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
45 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth",
|
46 |
+
map_location="cpu",
|
47 |
+
)
|
48 |
+
model.load_state_dict(state_dict, strict=True)
|
49 |
+
return model
|
50 |
+
|
51 |
+
|
52 |
+
def dino_vitb16(pretrained=True, **kwargs):
|
53 |
+
"""
|
54 |
+
ViT-Base/16x16 pre-trained with DINO.
|
55 |
+
Achieves 76.1% top-1 accuracy on ImageNet with k-NN classification.
|
56 |
+
"""
|
57 |
+
model = vits.__dict__["vit_base"](patch_size=16, num_classes=0, **kwargs)
|
58 |
+
if pretrained:
|
59 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
60 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
|
61 |
+
map_location="cpu",
|
62 |
+
)
|
63 |
+
model.load_state_dict(state_dict, strict=True)
|
64 |
+
return model
|
65 |
+
|
66 |
+
|
67 |
+
def dino_vitb8(pretrained=True, **kwargs):
|
68 |
+
"""
|
69 |
+
ViT-Base/8x8 pre-trained with DINO.
|
70 |
+
Achieves 77.4% top-1 accuracy on ImageNet with k-NN classification.
|
71 |
+
"""
|
72 |
+
model = vits.__dict__["vit_base"](patch_size=8, num_classes=0, **kwargs)
|
73 |
+
if pretrained:
|
74 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
75 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
|
76 |
+
map_location="cpu",
|
77 |
+
)
|
78 |
+
model.load_state_dict(state_dict, strict=True)
|
79 |
+
return model
|
80 |
+
|
81 |
+
|
82 |
+
def dino_resnet50(pretrained=True, **kwargs):
|
83 |
+
"""
|
84 |
+
ResNet-50 pre-trained with DINO.
|
85 |
+
Achieves 75.3% top-1 accuracy on ImageNet linear evaluation benchmark (requires to train `fc`).
|
86 |
+
"""
|
87 |
+
model = resnet50(pretrained=False, **kwargs)
|
88 |
+
model.fc = torch.nn.Identity()
|
89 |
+
if pretrained:
|
90 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
91 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
|
92 |
+
map_location="cpu",
|
93 |
+
)
|
94 |
+
model.load_state_dict(state_dict, strict=False)
|
95 |
+
return model
|
96 |
+
|
97 |
+
|
98 |
+
def dino_xcit_small_12_p16(pretrained=True, **kwargs):
|
99 |
+
"""
|
100 |
+
XCiT-Small-12/16 pre-trained with DINO.
|
101 |
+
"""
|
102 |
+
model = torch.hub.load(
|
103 |
+
"facebookresearch/xcit:main", "xcit_small_12_p16", num_classes=0, **kwargs
|
104 |
+
)
|
105 |
+
if pretrained:
|
106 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
107 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth",
|
108 |
+
map_location="cpu",
|
109 |
+
)
|
110 |
+
model.load_state_dict(state_dict, strict=True)
|
111 |
+
return model
|
112 |
+
|
113 |
+
|
114 |
+
def dino_xcit_small_12_p8(pretrained=True, **kwargs):
|
115 |
+
"""
|
116 |
+
XCiT-Small-12/8 pre-trained with DINO.
|
117 |
+
"""
|
118 |
+
model = torch.hub.load(
|
119 |
+
"facebookresearch/xcit:main", "xcit_small_12_p8", num_classes=0, **kwargs
|
120 |
+
)
|
121 |
+
if pretrained:
|
122 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
123 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth",
|
124 |
+
map_location="cpu",
|
125 |
+
)
|
126 |
+
model.load_state_dict(state_dict, strict=True)
|
127 |
+
return model
|
128 |
+
|
129 |
+
|
130 |
+
def dino_xcit_medium_24_p16(pretrained=True, **kwargs):
|
131 |
+
"""
|
132 |
+
XCiT-Medium-24/16 pre-trained with DINO.
|
133 |
+
"""
|
134 |
+
model = torch.hub.load(
|
135 |
+
"facebookresearch/xcit:main", "xcit_medium_24_p16", num_classes=0, **kwargs
|
136 |
+
)
|
137 |
+
if pretrained:
|
138 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
139 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth",
|
140 |
+
map_location="cpu",
|
141 |
+
)
|
142 |
+
model.load_state_dict(state_dict, strict=True)
|
143 |
+
return model
|
144 |
+
|
145 |
+
|
146 |
+
def dino_xcit_medium_24_p8(pretrained=True, **kwargs):
|
147 |
+
"""
|
148 |
+
XCiT-Medium-24/8 pre-trained with DINO.
|
149 |
+
"""
|
150 |
+
model = torch.hub.load(
|
151 |
+
"facebookresearch/xcit:main", "xcit_medium_24_p8", num_classes=0, **kwargs
|
152 |
+
)
|
153 |
+
if pretrained:
|
154 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
155 |
+
url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth",
|
156 |
+
map_location="cpu",
|
157 |
+
)
|
158 |
+
model.load_state_dict(state_dict, strict=True)
|
159 |
+
return model
|
dino/main_dino.py
ADDED
@@ -0,0 +1,689 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import argparse
|
15 |
+
import os
|
16 |
+
import sys
|
17 |
+
import datetime
|
18 |
+
import time
|
19 |
+
import math
|
20 |
+
import json
|
21 |
+
from pathlib import Path
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
from PIL import Image
|
25 |
+
import torch
|
26 |
+
import torch.nn as nn
|
27 |
+
import torch.distributed as dist
|
28 |
+
import torch.backends.cudnn as cudnn
|
29 |
+
import torch.nn.functional as F
|
30 |
+
from torchvision import datasets, transforms
|
31 |
+
from torchvision import models as torchvision_models
|
32 |
+
|
33 |
+
import utils
|
34 |
+
import vision_transformer as vits
|
35 |
+
from vision_transformer import DINOHead
|
36 |
+
|
37 |
+
torchvision_archs = sorted(
|
38 |
+
name
|
39 |
+
for name in torchvision_models.__dict__
|
40 |
+
if name.islower()
|
41 |
+
and not name.startswith("__")
|
42 |
+
and callable(torchvision_models.__dict__[name])
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
def get_args_parser():
|
47 |
+
parser = argparse.ArgumentParser("DINO", add_help=False)
|
48 |
+
|
49 |
+
# Model parameters
|
50 |
+
parser.add_argument(
|
51 |
+
"--arch",
|
52 |
+
default="vit_small",
|
53 |
+
type=str,
|
54 |
+
choices=["vit_tiny", "vit_small", "vit_base", "xcit", "deit_tiny", "deit_small"]
|
55 |
+
+ torchvision_archs
|
56 |
+
+ torch.hub.list("facebookresearch/xcit:main"),
|
57 |
+
help="""Name of architecture to train. For quick experiments with ViTs,
|
58 |
+
we recommend using vit_tiny or vit_small.""",
|
59 |
+
)
|
60 |
+
parser.add_argument(
|
61 |
+
"--patch_size",
|
62 |
+
default=16,
|
63 |
+
type=int,
|
64 |
+
help="""Size in pixels
|
65 |
+
of input square patches - default 16 (for 16x16 patches). Using smaller
|
66 |
+
values leads to better performance but requires more memory. Applies only
|
67 |
+
for ViTs (vit_tiny, vit_small and vit_base). If <16, we recommend disabling
|
68 |
+
mixed precision training (--use_fp16 false) to avoid unstabilities.""",
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
"--out_dim",
|
72 |
+
default=65536,
|
73 |
+
type=int,
|
74 |
+
help="""Dimensionality of
|
75 |
+
the DINO head output. For complex and large datasets large values (like 65k) work well.""",
|
76 |
+
)
|
77 |
+
parser.add_argument(
|
78 |
+
"--norm_last_layer",
|
79 |
+
default=True,
|
80 |
+
type=utils.bool_flag,
|
81 |
+
help="""Whether or not to weight normalize the last layer of the DINO head.
|
82 |
+
Not normalizing leads to better performance but can make the training unstable.
|
83 |
+
In our experiments, we typically set this paramater to False with vit_small and True with vit_base.""",
|
84 |
+
)
|
85 |
+
parser.add_argument(
|
86 |
+
"--momentum_teacher",
|
87 |
+
default=0.996,
|
88 |
+
type=float,
|
89 |
+
help="""Base EMA
|
90 |
+
parameter for teacher update. The value is increased to 1 during training with cosine schedule.
|
91 |
+
We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256.""",
|
92 |
+
)
|
93 |
+
parser.add_argument(
|
94 |
+
"--use_bn_in_head",
|
95 |
+
default=False,
|
96 |
+
type=utils.bool_flag,
|
97 |
+
help="Whether to use batch normalizations in projection head (Default: False)",
|
98 |
+
)
|
99 |
+
|
100 |
+
# Temperature teacher parameters
|
101 |
+
parser.add_argument(
|
102 |
+
"--warmup_teacher_temp",
|
103 |
+
default=0.04,
|
104 |
+
type=float,
|
105 |
+
help="""Initial value for the teacher temperature: 0.04 works well in most cases.
|
106 |
+
Try decreasing it if the training loss does not decrease.""",
|
107 |
+
)
|
108 |
+
parser.add_argument(
|
109 |
+
"--teacher_temp",
|
110 |
+
default=0.04,
|
111 |
+
type=float,
|
112 |
+
help="""Final value (after linear warmup)
|
113 |
+
of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend
|
114 |
+
starting with the default value of 0.04 and increase this slightly if needed.""",
|
115 |
+
)
|
116 |
+
parser.add_argument(
|
117 |
+
"--warmup_teacher_temp_epochs",
|
118 |
+
default=0,
|
119 |
+
type=int,
|
120 |
+
help="Number of warmup epochs for the teacher temperature (Default: 30).",
|
121 |
+
)
|
122 |
+
|
123 |
+
# Training/Optimization parameters
|
124 |
+
parser.add_argument(
|
125 |
+
"--use_fp16",
|
126 |
+
type=utils.bool_flag,
|
127 |
+
default=True,
|
128 |
+
help="""Whether or not
|
129 |
+
to use half precision for training. Improves training time and memory requirements,
|
130 |
+
but can provoke instability and slight decay of performance. We recommend disabling
|
131 |
+
mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""",
|
132 |
+
)
|
133 |
+
parser.add_argument(
|
134 |
+
"--weight_decay",
|
135 |
+
type=float,
|
136 |
+
default=0.04,
|
137 |
+
help="""Initial value of the
|
138 |
+
weight decay. With ViT, a smaller value at the beginning of training works well.""",
|
139 |
+
)
|
140 |
+
parser.add_argument(
|
141 |
+
"--weight_decay_end",
|
142 |
+
type=float,
|
143 |
+
default=0.4,
|
144 |
+
help="""Final value of the
|
145 |
+
weight decay. We use a cosine schedule for WD and using a larger decay by
|
146 |
+
the end of training improves performance for ViTs.""",
|
147 |
+
)
|
148 |
+
parser.add_argument(
|
149 |
+
"--clip_grad",
|
150 |
+
type=float,
|
151 |
+
default=3.0,
|
152 |
+
help="""Maximal parameter
|
153 |
+
gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
|
154 |
+
help optimization for larger ViT architectures. 0 for disabling.""",
|
155 |
+
)
|
156 |
+
parser.add_argument(
|
157 |
+
"--batch_size_per_gpu",
|
158 |
+
default=64,
|
159 |
+
type=int,
|
160 |
+
help="Per-GPU batch-size : number of distinct images loaded on one GPU.",
|
161 |
+
)
|
162 |
+
parser.add_argument(
|
163 |
+
"--epochs", default=100, type=int, help="Number of epochs of training."
|
164 |
+
)
|
165 |
+
parser.add_argument(
|
166 |
+
"--freeze_last_layer",
|
167 |
+
default=1,
|
168 |
+
type=int,
|
169 |
+
help="""Number of epochs
|
170 |
+
during which we keep the output layer fixed. Typically doing so during
|
171 |
+
the first epoch helps training. Try increasing this value if the loss does not decrease.""",
|
172 |
+
)
|
173 |
+
parser.add_argument(
|
174 |
+
"--lr",
|
175 |
+
default=0.0005,
|
176 |
+
type=float,
|
177 |
+
help="""Learning rate at the end of
|
178 |
+
linear warmup (highest LR used during training). The learning rate is linearly scaled
|
179 |
+
with the batch size, and specified here for a reference batch size of 256.""",
|
180 |
+
)
|
181 |
+
parser.add_argument(
|
182 |
+
"--warmup_epochs",
|
183 |
+
default=10,
|
184 |
+
type=int,
|
185 |
+
help="Number of epochs for the linear learning-rate warm up.",
|
186 |
+
)
|
187 |
+
parser.add_argument(
|
188 |
+
"--min_lr",
|
189 |
+
type=float,
|
190 |
+
default=1e-6,
|
191 |
+
help="""Target LR at the
|
192 |
+
end of optimization. We use a cosine LR schedule with linear warmup.""",
|
193 |
+
)
|
194 |
+
parser.add_argument(
|
195 |
+
"--optimizer",
|
196 |
+
default="adamw",
|
197 |
+
type=str,
|
198 |
+
choices=["adamw", "sgd", "lars"],
|
199 |
+
help="""Type of optimizer. We recommend using adamw with ViTs.""",
|
200 |
+
)
|
201 |
+
parser.add_argument(
|
202 |
+
"--drop_path_rate", type=float, default=0.1, help="stochastic depth rate"
|
203 |
+
)
|
204 |
+
|
205 |
+
# Multi-crop parameters
|
206 |
+
parser.add_argument(
|
207 |
+
"--global_crops_scale",
|
208 |
+
type=float,
|
209 |
+
nargs="+",
|
210 |
+
default=(0.4, 1.0),
|
211 |
+
help="""Scale range of the cropped image before resizing, relatively to the origin image.
|
212 |
+
Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
|
213 |
+
recommand using a wider range of scale ("--global_crops_scale 0.14 1." for example)""",
|
214 |
+
)
|
215 |
+
parser.add_argument(
|
216 |
+
"--local_crops_number",
|
217 |
+
type=int,
|
218 |
+
default=8,
|
219 |
+
help="""Number of small
|
220 |
+
local views to generate. Set this parameter to 0 to disable multi-crop training.
|
221 |
+
When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """,
|
222 |
+
)
|
223 |
+
parser.add_argument(
|
224 |
+
"--local_crops_scale",
|
225 |
+
type=float,
|
226 |
+
nargs="+",
|
227 |
+
default=(0.05, 0.4),
|
228 |
+
help="""Scale range of the cropped image before resizing, relatively to the origin image.
|
229 |
+
Used for small local view cropping of multi-crop.""",
|
230 |
+
)
|
231 |
+
|
232 |
+
# Misc
|
233 |
+
parser.add_argument(
|
234 |
+
"--data_path",
|
235 |
+
default="/path/to/imagenet/train/",
|
236 |
+
type=str,
|
237 |
+
help="Please specify path to the ImageNet training data.",
|
238 |
+
)
|
239 |
+
parser.add_argument(
|
240 |
+
"--output_dir", default=".", type=str, help="Path to save logs and checkpoints."
|
241 |
+
)
|
242 |
+
parser.add_argument(
|
243 |
+
"--saveckp_freq", default=20, type=int, help="Save checkpoint every x epochs."
|
244 |
+
)
|
245 |
+
parser.add_argument("--seed", default=0, type=int, help="Random seed.")
|
246 |
+
parser.add_argument(
|
247 |
+
"--num_workers",
|
248 |
+
default=10,
|
249 |
+
type=int,
|
250 |
+
help="Number of data loading workers per GPU.",
|
251 |
+
)
|
252 |
+
parser.add_argument(
|
253 |
+
"--dist_url",
|
254 |
+
default="env://",
|
255 |
+
type=str,
|
256 |
+
help="""url used to set up
|
257 |
+
distributed training; see https://pytorch.org/docs/stable/distributed.html""",
|
258 |
+
)
|
259 |
+
parser.add_argument(
|
260 |
+
"--local_rank",
|
261 |
+
default=0,
|
262 |
+
type=int,
|
263 |
+
help="Please ignore and do not set this argument.",
|
264 |
+
)
|
265 |
+
return parser
|
266 |
+
|
267 |
+
|
268 |
+
def train_dino(args):
|
269 |
+
utils.init_distributed_mode(args)
|
270 |
+
utils.fix_random_seeds(args.seed)
|
271 |
+
print("git:\n {}\n".format(utils.get_sha()))
|
272 |
+
print(
|
273 |
+
"\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
|
274 |
+
)
|
275 |
+
cudnn.benchmark = True
|
276 |
+
|
277 |
+
# ============ preparing data ... ============
|
278 |
+
transform = DataAugmentationDINO(
|
279 |
+
args.global_crops_scale,
|
280 |
+
args.local_crops_scale,
|
281 |
+
args.local_crops_number,
|
282 |
+
)
|
283 |
+
dataset = datasets.ImageFolder(args.data_path, transform=transform)
|
284 |
+
sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
|
285 |
+
data_loader = torch.utils.data.DataLoader(
|
286 |
+
dataset,
|
287 |
+
sampler=sampler,
|
288 |
+
batch_size=args.batch_size_per_gpu,
|
289 |
+
num_workers=args.num_workers,
|
290 |
+
pin_memory=True,
|
291 |
+
drop_last=True,
|
292 |
+
)
|
293 |
+
print(f"Data loaded: there are {len(dataset)} images.")
|
294 |
+
|
295 |
+
# ============ building student and teacher networks ... ============
|
296 |
+
# we changed the name DeiT-S for ViT-S to avoid confusions
|
297 |
+
args.arch = args.arch.replace("deit", "vit")
|
298 |
+
# if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
|
299 |
+
if args.arch in vits.__dict__.keys():
|
300 |
+
student = vits.__dict__[args.arch](
|
301 |
+
patch_size=args.patch_size,
|
302 |
+
drop_path_rate=args.drop_path_rate, # stochastic depth
|
303 |
+
)
|
304 |
+
teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
|
305 |
+
embed_dim = student.embed_dim
|
306 |
+
# if the network is a XCiT
|
307 |
+
elif args.arch in torch.hub.list("facebookresearch/xcit:main"):
|
308 |
+
student = torch.hub.load(
|
309 |
+
"facebookresearch/xcit:main",
|
310 |
+
args.arch,
|
311 |
+
pretrained=False,
|
312 |
+
drop_path_rate=args.drop_path_rate,
|
313 |
+
)
|
314 |
+
teacher = torch.hub.load(
|
315 |
+
"facebookresearch/xcit:main", args.arch, pretrained=False
|
316 |
+
)
|
317 |
+
embed_dim = student.embed_dim
|
318 |
+
# otherwise, we check if the architecture is in torchvision models
|
319 |
+
elif args.arch in torchvision_models.__dict__.keys():
|
320 |
+
student = torchvision_models.__dict__[args.arch]()
|
321 |
+
teacher = torchvision_models.__dict__[args.arch]()
|
322 |
+
embed_dim = student.fc.weight.shape[1]
|
323 |
+
else:
|
324 |
+
print(f"Unknow architecture: {args.arch}")
|
325 |
+
|
326 |
+
# multi-crop wrapper handles forward with inputs of different resolutions
|
327 |
+
student = utils.MultiCropWrapper(
|
328 |
+
student,
|
329 |
+
DINOHead(
|
330 |
+
embed_dim,
|
331 |
+
args.out_dim,
|
332 |
+
use_bn=args.use_bn_in_head,
|
333 |
+
norm_last_layer=args.norm_last_layer,
|
334 |
+
),
|
335 |
+
)
|
336 |
+
teacher = utils.MultiCropWrapper(
|
337 |
+
teacher,
|
338 |
+
DINOHead(embed_dim, args.out_dim, args.use_bn_in_head),
|
339 |
+
)
|
340 |
+
# move networks to gpu
|
341 |
+
student, teacher = student.cuda(), teacher.cuda()
|
342 |
+
# synchronize batch norms (if any)
|
343 |
+
if utils.has_batchnorms(student):
|
344 |
+
student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
|
345 |
+
teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)
|
346 |
+
|
347 |
+
# we need DDP wrapper to have synchro batch norms working...
|
348 |
+
teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu])
|
349 |
+
teacher_without_ddp = teacher.module
|
350 |
+
else:
|
351 |
+
# teacher_without_ddp and teacher are the same thing
|
352 |
+
teacher_without_ddp = teacher
|
353 |
+
student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
|
354 |
+
# teacher and student start with the same weights
|
355 |
+
teacher_without_ddp.load_state_dict(student.module.state_dict())
|
356 |
+
# there is no backpropagation through the teacher, so no need for gradients
|
357 |
+
for p in teacher.parameters():
|
358 |
+
p.requires_grad = False
|
359 |
+
print(f"Student and Teacher are built: they are both {args.arch} network.")
|
360 |
+
|
361 |
+
# ============ preparing loss ... ============
|
362 |
+
dino_loss = DINOLoss(
|
363 |
+
args.out_dim,
|
364 |
+
args.local_crops_number
|
365 |
+
+ 2, # total number of crops = 2 global crops + local_crops_number
|
366 |
+
args.warmup_teacher_temp,
|
367 |
+
args.teacher_temp,
|
368 |
+
args.warmup_teacher_temp_epochs,
|
369 |
+
args.epochs,
|
370 |
+
).cuda()
|
371 |
+
|
372 |
+
# ============ preparing optimizer ... ============
|
373 |
+
params_groups = utils.get_params_groups(student)
|
374 |
+
if args.optimizer == "adamw":
|
375 |
+
optimizer = torch.optim.AdamW(params_groups) # to use with ViTs
|
376 |
+
elif args.optimizer == "sgd":
|
377 |
+
optimizer = torch.optim.SGD(
|
378 |
+
params_groups, lr=0, momentum=0.9
|
379 |
+
) # lr is set by scheduler
|
380 |
+
elif args.optimizer == "lars":
|
381 |
+
optimizer = utils.LARS(params_groups) # to use with convnet and large batches
|
382 |
+
# for mixed precision training
|
383 |
+
fp16_scaler = None
|
384 |
+
if args.use_fp16:
|
385 |
+
fp16_scaler = torch.cuda.amp.GradScaler()
|
386 |
+
|
387 |
+
# ============ init schedulers ... ============
|
388 |
+
lr_schedule = utils.cosine_scheduler(
|
389 |
+
args.lr
|
390 |
+
* (args.batch_size_per_gpu * utils.get_world_size())
|
391 |
+
/ 256.0, # linear scaling rule
|
392 |
+
args.min_lr,
|
393 |
+
args.epochs,
|
394 |
+
len(data_loader),
|
395 |
+
warmup_epochs=args.warmup_epochs,
|
396 |
+
)
|
397 |
+
wd_schedule = utils.cosine_scheduler(
|
398 |
+
args.weight_decay,
|
399 |
+
args.weight_decay_end,
|
400 |
+
args.epochs,
|
401 |
+
len(data_loader),
|
402 |
+
)
|
403 |
+
# momentum parameter is increased to 1. during training with a cosine schedule
|
404 |
+
momentum_schedule = utils.cosine_scheduler(
|
405 |
+
args.momentum_teacher, 1, args.epochs, len(data_loader)
|
406 |
+
)
|
407 |
+
print(f"Loss, optimizer and schedulers ready.")
|
408 |
+
|
409 |
+
# ============ optionally resume training ... ============
|
410 |
+
to_restore = {"epoch": 0}
|
411 |
+
utils.restart_from_checkpoint(
|
412 |
+
os.path.join(args.output_dir, "checkpoint.pth"),
|
413 |
+
run_variables=to_restore,
|
414 |
+
student=student,
|
415 |
+
teacher=teacher,
|
416 |
+
optimizer=optimizer,
|
417 |
+
fp16_scaler=fp16_scaler,
|
418 |
+
dino_loss=dino_loss,
|
419 |
+
)
|
420 |
+
start_epoch = to_restore["epoch"]
|
421 |
+
|
422 |
+
start_time = time.time()
|
423 |
+
print("Starting DINO training !")
|
424 |
+
for epoch in range(start_epoch, args.epochs):
|
425 |
+
data_loader.sampler.set_epoch(epoch)
|
426 |
+
|
427 |
+
# ============ training one epoch of DINO ... ============
|
428 |
+
train_stats = train_one_epoch(
|
429 |
+
student,
|
430 |
+
teacher,
|
431 |
+
teacher_without_ddp,
|
432 |
+
dino_loss,
|
433 |
+
data_loader,
|
434 |
+
optimizer,
|
435 |
+
lr_schedule,
|
436 |
+
wd_schedule,
|
437 |
+
momentum_schedule,
|
438 |
+
epoch,
|
439 |
+
fp16_scaler,
|
440 |
+
args,
|
441 |
+
)
|
442 |
+
|
443 |
+
# ============ writing logs ... ============
|
444 |
+
save_dict = {
|
445 |
+
"student": student.state_dict(),
|
446 |
+
"teacher": teacher.state_dict(),
|
447 |
+
"optimizer": optimizer.state_dict(),
|
448 |
+
"epoch": epoch + 1,
|
449 |
+
"args": args,
|
450 |
+
"dino_loss": dino_loss.state_dict(),
|
451 |
+
}
|
452 |
+
if fp16_scaler is not None:
|
453 |
+
save_dict["fp16_scaler"] = fp16_scaler.state_dict()
|
454 |
+
utils.save_on_master(save_dict, os.path.join(args.output_dir, "checkpoint.pth"))
|
455 |
+
if args.saveckp_freq and epoch % args.saveckp_freq == 0:
|
456 |
+
utils.save_on_master(
|
457 |
+
save_dict, os.path.join(args.output_dir, f"checkpoint{epoch:04}.pth")
|
458 |
+
)
|
459 |
+
log_stats = {
|
460 |
+
**{f"train_{k}": v for k, v in train_stats.items()},
|
461 |
+
"epoch": epoch,
|
462 |
+
}
|
463 |
+
if utils.is_main_process():
|
464 |
+
with (Path(args.output_dir) / "log.txt").open("a") as f:
|
465 |
+
f.write(json.dumps(log_stats) + "\n")
|
466 |
+
total_time = time.time() - start_time
|
467 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
468 |
+
print("Training time {}".format(total_time_str))
|
469 |
+
|
470 |
+
|
471 |
+
def train_one_epoch(
|
472 |
+
student,
|
473 |
+
teacher,
|
474 |
+
teacher_without_ddp,
|
475 |
+
dino_loss,
|
476 |
+
data_loader,
|
477 |
+
optimizer,
|
478 |
+
lr_schedule,
|
479 |
+
wd_schedule,
|
480 |
+
momentum_schedule,
|
481 |
+
epoch,
|
482 |
+
fp16_scaler,
|
483 |
+
args,
|
484 |
+
):
|
485 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
486 |
+
header = "Epoch: [{}/{}]".format(epoch, args.epochs)
|
487 |
+
for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):
|
488 |
+
# update weight decay and learning rate according to their schedule
|
489 |
+
it = len(data_loader) * epoch + it # global training iteration
|
490 |
+
for i, param_group in enumerate(optimizer.param_groups):
|
491 |
+
param_group["lr"] = lr_schedule[it]
|
492 |
+
if i == 0: # only the first group is regularized
|
493 |
+
param_group["weight_decay"] = wd_schedule[it]
|
494 |
+
|
495 |
+
# move images to gpu
|
496 |
+
images = [im.cuda(non_blocking=True) for im in images]
|
497 |
+
# teacher and student forward passes + compute dino loss
|
498 |
+
with torch.cuda.amp.autocast(fp16_scaler is not None):
|
499 |
+
teacher_output = teacher(
|
500 |
+
images[:2]
|
501 |
+
) # only the 2 global views pass through the teacher
|
502 |
+
student_output = student(images)
|
503 |
+
loss = dino_loss(student_output, teacher_output, epoch)
|
504 |
+
|
505 |
+
if not math.isfinite(loss.item()):
|
506 |
+
print("Loss is {}, stopping training".format(loss.item()), force=True)
|
507 |
+
sys.exit(1)
|
508 |
+
|
509 |
+
# student update
|
510 |
+
optimizer.zero_grad()
|
511 |
+
param_norms = None
|
512 |
+
if fp16_scaler is None:
|
513 |
+
loss.backward()
|
514 |
+
if args.clip_grad:
|
515 |
+
param_norms = utils.clip_gradients(student, args.clip_grad)
|
516 |
+
utils.cancel_gradients_last_layer(epoch, student, args.freeze_last_layer)
|
517 |
+
optimizer.step()
|
518 |
+
else:
|
519 |
+
fp16_scaler.scale(loss).backward()
|
520 |
+
if args.clip_grad:
|
521 |
+
fp16_scaler.unscale_(
|
522 |
+
optimizer
|
523 |
+
) # unscale the gradients of optimizer's assigned params in-place
|
524 |
+
param_norms = utils.clip_gradients(student, args.clip_grad)
|
525 |
+
utils.cancel_gradients_last_layer(epoch, student, args.freeze_last_layer)
|
526 |
+
fp16_scaler.step(optimizer)
|
527 |
+
fp16_scaler.update()
|
528 |
+
|
529 |
+
# EMA update for the teacher
|
530 |
+
with torch.no_grad():
|
531 |
+
m = momentum_schedule[it] # momentum parameter
|
532 |
+
for param_q, param_k in zip(
|
533 |
+
student.module.parameters(), teacher_without_ddp.parameters()
|
534 |
+
):
|
535 |
+
param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
|
536 |
+
|
537 |
+
# logging
|
538 |
+
torch.cuda.synchronize()
|
539 |
+
metric_logger.update(loss=loss.item())
|
540 |
+
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
541 |
+
metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
|
542 |
+
# gather the stats from all processes
|
543 |
+
metric_logger.synchronize_between_processes()
|
544 |
+
print("Averaged stats:", metric_logger)
|
545 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
546 |
+
|
547 |
+
|
548 |
+
class DINOLoss(nn.Module):
|
549 |
+
def __init__(
|
550 |
+
self,
|
551 |
+
out_dim,
|
552 |
+
ncrops,
|
553 |
+
warmup_teacher_temp,
|
554 |
+
teacher_temp,
|
555 |
+
warmup_teacher_temp_epochs,
|
556 |
+
nepochs,
|
557 |
+
student_temp=0.1,
|
558 |
+
center_momentum=0.9,
|
559 |
+
):
|
560 |
+
super().__init__()
|
561 |
+
self.student_temp = student_temp
|
562 |
+
self.center_momentum = center_momentum
|
563 |
+
self.ncrops = ncrops
|
564 |
+
self.register_buffer("center", torch.zeros(1, out_dim))
|
565 |
+
# we apply a warm up for the teacher temperature because
|
566 |
+
# a too high temperature makes the training instable at the beginning
|
567 |
+
self.teacher_temp_schedule = np.concatenate(
|
568 |
+
(
|
569 |
+
np.linspace(
|
570 |
+
warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs
|
571 |
+
),
|
572 |
+
np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp,
|
573 |
+
)
|
574 |
+
)
|
575 |
+
|
576 |
+
def forward(self, student_output, teacher_output, epoch):
|
577 |
+
"""
|
578 |
+
Cross-entropy between softmax outputs of the teacher and student networks.
|
579 |
+
"""
|
580 |
+
student_out = student_output / self.student_temp
|
581 |
+
student_out = student_out.chunk(self.ncrops)
|
582 |
+
|
583 |
+
# teacher centering and sharpening
|
584 |
+
temp = self.teacher_temp_schedule[epoch]
|
585 |
+
teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
|
586 |
+
teacher_out = teacher_out.detach().chunk(2)
|
587 |
+
|
588 |
+
total_loss = 0
|
589 |
+
n_loss_terms = 0
|
590 |
+
for iq, q in enumerate(teacher_out):
|
591 |
+
for v in range(len(student_out)):
|
592 |
+
if v == iq:
|
593 |
+
# we skip cases where student and teacher operate on the same view
|
594 |
+
continue
|
595 |
+
loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
|
596 |
+
total_loss += loss.mean()
|
597 |
+
n_loss_terms += 1
|
598 |
+
total_loss /= n_loss_terms
|
599 |
+
self.update_center(teacher_output)
|
600 |
+
return total_loss
|
601 |
+
|
602 |
+
@torch.no_grad()
|
603 |
+
def update_center(self, teacher_output):
|
604 |
+
"""
|
605 |
+
Update center used for teacher output.
|
606 |
+
"""
|
607 |
+
batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
|
608 |
+
dist.all_reduce(batch_center)
|
609 |
+
batch_center = batch_center / (len(teacher_output) * dist.get_world_size())
|
610 |
+
|
611 |
+
# ema update
|
612 |
+
self.center = self.center * self.center_momentum + batch_center * (
|
613 |
+
1 - self.center_momentum
|
614 |
+
)
|
615 |
+
|
616 |
+
|
617 |
+
class DataAugmentationDINO(object):
|
618 |
+
def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
|
619 |
+
flip_and_color_jitter = transforms.Compose(
|
620 |
+
[
|
621 |
+
transforms.RandomHorizontalFlip(p=0.5),
|
622 |
+
transforms.RandomApply(
|
623 |
+
[
|
624 |
+
transforms.ColorJitter(
|
625 |
+
brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
|
626 |
+
)
|
627 |
+
],
|
628 |
+
p=0.8,
|
629 |
+
),
|
630 |
+
transforms.RandomGrayscale(p=0.2),
|
631 |
+
]
|
632 |
+
)
|
633 |
+
normalize = transforms.Compose(
|
634 |
+
[
|
635 |
+
transforms.ToTensor(),
|
636 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
637 |
+
]
|
638 |
+
)
|
639 |
+
|
640 |
+
# first global crop
|
641 |
+
self.global_transfo1 = transforms.Compose(
|
642 |
+
[
|
643 |
+
transforms.RandomResizedCrop(
|
644 |
+
224, scale=global_crops_scale, interpolation=Image.BICUBIC
|
645 |
+
),
|
646 |
+
flip_and_color_jitter,
|
647 |
+
utils.GaussianBlur(1.0),
|
648 |
+
normalize,
|
649 |
+
]
|
650 |
+
)
|
651 |
+
# second global crop
|
652 |
+
self.global_transfo2 = transforms.Compose(
|
653 |
+
[
|
654 |
+
transforms.RandomResizedCrop(
|
655 |
+
224, scale=global_crops_scale, interpolation=Image.BICUBIC
|
656 |
+
),
|
657 |
+
flip_and_color_jitter,
|
658 |
+
utils.GaussianBlur(0.1),
|
659 |
+
utils.Solarization(0.2),
|
660 |
+
normalize,
|
661 |
+
]
|
662 |
+
)
|
663 |
+
# transformation for the local small crops
|
664 |
+
self.local_crops_number = local_crops_number
|
665 |
+
self.local_transfo = transforms.Compose(
|
666 |
+
[
|
667 |
+
transforms.RandomResizedCrop(
|
668 |
+
96, scale=local_crops_scale, interpolation=Image.BICUBIC
|
669 |
+
),
|
670 |
+
flip_and_color_jitter,
|
671 |
+
utils.GaussianBlur(p=0.5),
|
672 |
+
normalize,
|
673 |
+
]
|
674 |
+
)
|
675 |
+
|
676 |
+
def __call__(self, image):
|
677 |
+
crops = []
|
678 |
+
crops.append(self.global_transfo1(image))
|
679 |
+
crops.append(self.global_transfo2(image))
|
680 |
+
for _ in range(self.local_crops_number):
|
681 |
+
crops.append(self.local_transfo(image))
|
682 |
+
return crops
|
683 |
+
|
684 |
+
|
685 |
+
if __name__ == "__main__":
|
686 |
+
parser = argparse.ArgumentParser("DINO", parents=[get_args_parser()])
|
687 |
+
args = parser.parse_args()
|
688 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
689 |
+
train_dino(args)
|
dino/run_with_submitit.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
A script to run multinode training with submitit.
|
16 |
+
Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py
|
17 |
+
"""
|
18 |
+
import argparse
|
19 |
+
import os
|
20 |
+
import uuid
|
21 |
+
from pathlib import Path
|
22 |
+
|
23 |
+
import main_dino
|
24 |
+
import submitit
|
25 |
+
|
26 |
+
|
27 |
+
def parse_args():
|
28 |
+
parser = argparse.ArgumentParser(
|
29 |
+
"Submitit for DINO", parents=[main_dino.get_args_parser()]
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
"--ngpus", default=8, type=int, help="Number of gpus to request on each node"
|
33 |
+
)
|
34 |
+
parser.add_argument(
|
35 |
+
"--nodes", default=2, type=int, help="Number of nodes to request"
|
36 |
+
)
|
37 |
+
parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job")
|
38 |
+
|
39 |
+
parser.add_argument(
|
40 |
+
"--partition", default="learnfair", type=str, help="Partition where to submit"
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
"--use_volta32", action="store_true", help="Big models? Use this"
|
44 |
+
)
|
45 |
+
parser.add_argument(
|
46 |
+
"--comment",
|
47 |
+
default="",
|
48 |
+
type=str,
|
49 |
+
help="Comment to pass to scheduler, e.g. priority message",
|
50 |
+
)
|
51 |
+
return parser.parse_args()
|
52 |
+
|
53 |
+
|
54 |
+
def get_shared_folder() -> Path:
|
55 |
+
user = os.getenv("USER")
|
56 |
+
if Path("/checkpoint/").is_dir():
|
57 |
+
p = Path(f"/checkpoint/{user}/experiments")
|
58 |
+
p.mkdir(exist_ok=True)
|
59 |
+
return p
|
60 |
+
raise RuntimeError("No shared folder available")
|
61 |
+
|
62 |
+
|
63 |
+
def get_init_file():
|
64 |
+
# Init file must not exist, but it's parent dir must exist.
|
65 |
+
os.makedirs(str(get_shared_folder()), exist_ok=True)
|
66 |
+
init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
|
67 |
+
if init_file.exists():
|
68 |
+
os.remove(str(init_file))
|
69 |
+
return init_file
|
70 |
+
|
71 |
+
|
72 |
+
class Trainer(object):
|
73 |
+
def __init__(self, args):
|
74 |
+
self.args = args
|
75 |
+
|
76 |
+
def __call__(self):
|
77 |
+
import main_dino
|
78 |
+
|
79 |
+
self._setup_gpu_args()
|
80 |
+
main_dino.train_dino(self.args)
|
81 |
+
|
82 |
+
def checkpoint(self):
|
83 |
+
import os
|
84 |
+
import submitit
|
85 |
+
|
86 |
+
self.args.dist_url = get_init_file().as_uri()
|
87 |
+
print("Requeuing ", self.args)
|
88 |
+
empty_trainer = type(self)(self.args)
|
89 |
+
return submitit.helpers.DelayedSubmission(empty_trainer)
|
90 |
+
|
91 |
+
def _setup_gpu_args(self):
|
92 |
+
import submitit
|
93 |
+
from pathlib import Path
|
94 |
+
|
95 |
+
job_env = submitit.JobEnvironment()
|
96 |
+
self.args.output_dir = Path(
|
97 |
+
str(self.args.output_dir).replace("%j", str(job_env.job_id))
|
98 |
+
)
|
99 |
+
self.args.gpu = job_env.local_rank
|
100 |
+
self.args.rank = job_env.global_rank
|
101 |
+
self.args.world_size = job_env.num_tasks
|
102 |
+
print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
|
103 |
+
|
104 |
+
|
105 |
+
def main():
|
106 |
+
args = parse_args()
|
107 |
+
if args.output_dir == "":
|
108 |
+
args.output_dir = get_shared_folder() / "%j"
|
109 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
110 |
+
executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30)
|
111 |
+
|
112 |
+
num_gpus_per_node = args.ngpus
|
113 |
+
nodes = args.nodes
|
114 |
+
timeout_min = args.timeout
|
115 |
+
|
116 |
+
partition = args.partition
|
117 |
+
kwargs = {}
|
118 |
+
if args.use_volta32:
|
119 |
+
kwargs["slurm_constraint"] = "volta32gb"
|
120 |
+
if args.comment:
|
121 |
+
kwargs["slurm_comment"] = args.comment
|
122 |
+
|
123 |
+
executor.update_parameters(
|
124 |
+
mem_gb=40 * num_gpus_per_node,
|
125 |
+
gpus_per_node=num_gpus_per_node,
|
126 |
+
tasks_per_node=num_gpus_per_node, # one task per GPU
|
127 |
+
cpus_per_task=10,
|
128 |
+
nodes=nodes,
|
129 |
+
timeout_min=timeout_min, # max is 60 * 72
|
130 |
+
# Below are cluster dependent parameters
|
131 |
+
slurm_partition=partition,
|
132 |
+
slurm_signal_delay_s=120,
|
133 |
+
**kwargs,
|
134 |
+
)
|
135 |
+
|
136 |
+
executor.update_parameters(name="dino")
|
137 |
+
|
138 |
+
args.dist_url = get_init_file().as_uri()
|
139 |
+
|
140 |
+
trainer = Trainer(args)
|
141 |
+
job = executor.submit(trainer)
|
142 |
+
|
143 |
+
print(f"Submitted job_id: {job.job_id}")
|
144 |
+
print(f"Logs and checkpoints will be saved at: {args.output_dir}")
|
145 |
+
|
146 |
+
|
147 |
+
if __name__ == "__main__":
|
148 |
+
main()
|
dino/utils.py
ADDED
@@ -0,0 +1,912 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Misc functions.
|
16 |
+
|
17 |
+
Mostly copy-paste from torchvision references or other public repos like DETR:
|
18 |
+
https://github.com/facebookresearch/detr/blob/master/util/misc.py
|
19 |
+
"""
|
20 |
+
import os
|
21 |
+
import sys
|
22 |
+
import time
|
23 |
+
import math
|
24 |
+
import random
|
25 |
+
import datetime
|
26 |
+
import subprocess
|
27 |
+
from collections import defaultdict, deque
|
28 |
+
|
29 |
+
import numpy as np
|
30 |
+
import torch
|
31 |
+
from torch import nn
|
32 |
+
import torch.distributed as dist
|
33 |
+
from PIL import ImageFilter, ImageOps
|
34 |
+
|
35 |
+
|
36 |
+
class GaussianBlur(object):
|
37 |
+
"""
|
38 |
+
Apply Gaussian Blur to the PIL image.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(self, p=0.5, radius_min=0.1, radius_max=2.0):
|
42 |
+
self.prob = p
|
43 |
+
self.radius_min = radius_min
|
44 |
+
self.radius_max = radius_max
|
45 |
+
|
46 |
+
def __call__(self, img):
|
47 |
+
do_it = random.random() <= self.prob
|
48 |
+
if not do_it:
|
49 |
+
return img
|
50 |
+
|
51 |
+
return img.filter(
|
52 |
+
ImageFilter.GaussianBlur(
|
53 |
+
radius=random.uniform(self.radius_min, self.radius_max)
|
54 |
+
)
|
55 |
+
)
|
56 |
+
|
57 |
+
|
58 |
+
class Solarization(object):
|
59 |
+
"""
|
60 |
+
Apply Solarization to the PIL image.
|
61 |
+
"""
|
62 |
+
|
63 |
+
def __init__(self, p):
|
64 |
+
self.p = p
|
65 |
+
|
66 |
+
def __call__(self, img):
|
67 |
+
if random.random() < self.p:
|
68 |
+
return ImageOps.solarize(img)
|
69 |
+
else:
|
70 |
+
return img
|
71 |
+
|
72 |
+
|
73 |
+
def load_pretrained_weights(
|
74 |
+
model, pretrained_weights, checkpoint_key, model_name, patch_size
|
75 |
+
):
|
76 |
+
if os.path.isfile(pretrained_weights):
|
77 |
+
state_dict = torch.load(pretrained_weights, map_location="cpu")
|
78 |
+
if checkpoint_key is not None and checkpoint_key in state_dict:
|
79 |
+
print(f"Take key {checkpoint_key} in provided checkpoint dict")
|
80 |
+
state_dict = state_dict[checkpoint_key]
|
81 |
+
# remove `module.` prefix
|
82 |
+
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
83 |
+
# remove `backbone.` prefix induced by multicrop wrapper
|
84 |
+
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
|
85 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
86 |
+
print(
|
87 |
+
"Pretrained weights found at {} and loaded with msg: {}".format(
|
88 |
+
pretrained_weights, msg
|
89 |
+
)
|
90 |
+
)
|
91 |
+
else:
|
92 |
+
print(
|
93 |
+
"Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
|
94 |
+
)
|
95 |
+
url = None
|
96 |
+
if model_name == "vit_small" and patch_size == 16:
|
97 |
+
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
|
98 |
+
elif model_name == "vit_small" and patch_size == 8:
|
99 |
+
url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
|
100 |
+
elif model_name == "vit_base" and patch_size == 16:
|
101 |
+
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
|
102 |
+
elif model_name == "vit_base" and patch_size == 8:
|
103 |
+
url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
|
104 |
+
elif model_name == "xcit_small_12_p16":
|
105 |
+
url = "dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth"
|
106 |
+
elif model_name == "xcit_small_12_p8":
|
107 |
+
url = "dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth"
|
108 |
+
elif model_name == "xcit_medium_24_p16":
|
109 |
+
url = (
|
110 |
+
"dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth"
|
111 |
+
)
|
112 |
+
elif model_name == "xcit_medium_24_p8":
|
113 |
+
url = "dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth"
|
114 |
+
elif model_name == "resnet50":
|
115 |
+
url = "dino_resnet50_pretrain/dino_resnet50_pretrain.pth"
|
116 |
+
if url is not None:
|
117 |
+
print(
|
118 |
+
"Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
|
119 |
+
)
|
120 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
121 |
+
url="https://dl.fbaipublicfiles.com/dino/" + url
|
122 |
+
)
|
123 |
+
model.load_state_dict(state_dict, strict=True)
|
124 |
+
else:
|
125 |
+
print(
|
126 |
+
"There is no reference weights available for this model => We use random weights."
|
127 |
+
)
|
128 |
+
|
129 |
+
|
130 |
+
def load_pretrained_linear_weights(linear_classifier, model_name, patch_size):
|
131 |
+
url = None
|
132 |
+
if model_name == "vit_small" and patch_size == 16:
|
133 |
+
url = "dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth"
|
134 |
+
elif model_name == "vit_small" and patch_size == 8:
|
135 |
+
url = "dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth"
|
136 |
+
elif model_name == "vit_base" and patch_size == 16:
|
137 |
+
url = "dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth"
|
138 |
+
elif model_name == "vit_base" and patch_size == 8:
|
139 |
+
url = "dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth"
|
140 |
+
elif model_name == "resnet50":
|
141 |
+
url = "dino_resnet50_pretrain/dino_resnet50_linearweights.pth"
|
142 |
+
if url is not None:
|
143 |
+
print("We load the reference pretrained linear weights.")
|
144 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
145 |
+
url="https://dl.fbaipublicfiles.com/dino/" + url
|
146 |
+
)["state_dict"]
|
147 |
+
linear_classifier.load_state_dict(state_dict, strict=True)
|
148 |
+
else:
|
149 |
+
print("We use random linear weights.")
|
150 |
+
|
151 |
+
|
152 |
+
def clip_gradients(model, clip):
|
153 |
+
norms = []
|
154 |
+
for name, p in model.named_parameters():
|
155 |
+
if p.grad is not None:
|
156 |
+
param_norm = p.grad.data.norm(2)
|
157 |
+
norms.append(param_norm.item())
|
158 |
+
clip_coef = clip / (param_norm + 1e-6)
|
159 |
+
if clip_coef < 1:
|
160 |
+
p.grad.data.mul_(clip_coef)
|
161 |
+
return norms
|
162 |
+
|
163 |
+
|
164 |
+
def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
|
165 |
+
if epoch >= freeze_last_layer:
|
166 |
+
return
|
167 |
+
for n, p in model.named_parameters():
|
168 |
+
if "last_layer" in n:
|
169 |
+
p.grad = None
|
170 |
+
|
171 |
+
|
172 |
+
def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
|
173 |
+
"""
|
174 |
+
Re-start from checkpoint
|
175 |
+
"""
|
176 |
+
if not os.path.isfile(ckp_path):
|
177 |
+
return
|
178 |
+
print("Found checkpoint at {}".format(ckp_path))
|
179 |
+
|
180 |
+
# open checkpoint file
|
181 |
+
checkpoint = torch.load(ckp_path, map_location="cpu")
|
182 |
+
|
183 |
+
# key is what to look for in the checkpoint file
|
184 |
+
# value is the object to load
|
185 |
+
# example: {'state_dict': model}
|
186 |
+
for key, value in kwargs.items():
|
187 |
+
if key in checkpoint and value is not None:
|
188 |
+
try:
|
189 |
+
msg = value.load_state_dict(checkpoint[key], strict=False)
|
190 |
+
print(
|
191 |
+
"=> loaded '{}' from checkpoint '{}' with msg {}".format(
|
192 |
+
key, ckp_path, msg
|
193 |
+
)
|
194 |
+
)
|
195 |
+
except TypeError:
|
196 |
+
try:
|
197 |
+
msg = value.load_state_dict(checkpoint[key])
|
198 |
+
print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
|
199 |
+
except ValueError:
|
200 |
+
print(
|
201 |
+
"=> failed to load '{}' from checkpoint: '{}'".format(
|
202 |
+
key, ckp_path
|
203 |
+
)
|
204 |
+
)
|
205 |
+
else:
|
206 |
+
print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))
|
207 |
+
|
208 |
+
# re load variable important for the run
|
209 |
+
if run_variables is not None:
|
210 |
+
for var_name in run_variables:
|
211 |
+
if var_name in checkpoint:
|
212 |
+
run_variables[var_name] = checkpoint[var_name]
|
213 |
+
|
214 |
+
|
215 |
+
def cosine_scheduler(
|
216 |
+
base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0
|
217 |
+
):
|
218 |
+
warmup_schedule = np.array([])
|
219 |
+
warmup_iters = warmup_epochs * niter_per_ep
|
220 |
+
if warmup_epochs > 0:
|
221 |
+
warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
|
222 |
+
|
223 |
+
iters = np.arange(epochs * niter_per_ep - warmup_iters)
|
224 |
+
schedule = final_value + 0.5 * (base_value - final_value) * (
|
225 |
+
1 + np.cos(np.pi * iters / len(iters))
|
226 |
+
)
|
227 |
+
|
228 |
+
schedule = np.concatenate((warmup_schedule, schedule))
|
229 |
+
assert len(schedule) == epochs * niter_per_ep
|
230 |
+
return schedule
|
231 |
+
|
232 |
+
|
233 |
+
def bool_flag(s):
|
234 |
+
"""
|
235 |
+
Parse boolean arguments from the command line.
|
236 |
+
"""
|
237 |
+
FALSY_STRINGS = {"off", "false", "0"}
|
238 |
+
TRUTHY_STRINGS = {"on", "true", "1"}
|
239 |
+
if s.lower() in FALSY_STRINGS:
|
240 |
+
return False
|
241 |
+
elif s.lower() in TRUTHY_STRINGS:
|
242 |
+
return True
|
243 |
+
else:
|
244 |
+
raise argparse.ArgumentTypeError("invalid value for a boolean flag")
|
245 |
+
|
246 |
+
|
247 |
+
def fix_random_seeds(seed=31):
|
248 |
+
"""
|
249 |
+
Fix random seeds.
|
250 |
+
"""
|
251 |
+
torch.manual_seed(seed)
|
252 |
+
torch.cuda.manual_seed_all(seed)
|
253 |
+
np.random.seed(seed)
|
254 |
+
|
255 |
+
|
256 |
+
class SmoothedValue(object):
|
257 |
+
"""Track a series of values and provide access to smoothed values over a
|
258 |
+
window or the global series average.
|
259 |
+
"""
|
260 |
+
|
261 |
+
def __init__(self, window_size=20, fmt=None):
|
262 |
+
if fmt is None:
|
263 |
+
fmt = "{median:.6f} ({global_avg:.6f})"
|
264 |
+
self.deque = deque(maxlen=window_size)
|
265 |
+
self.total = 0.0
|
266 |
+
self.count = 0
|
267 |
+
self.fmt = fmt
|
268 |
+
|
269 |
+
def update(self, value, n=1):
|
270 |
+
self.deque.append(value)
|
271 |
+
self.count += n
|
272 |
+
self.total += value * n
|
273 |
+
|
274 |
+
def synchronize_between_processes(self):
|
275 |
+
"""
|
276 |
+
Warning: does not synchronize the deque!
|
277 |
+
"""
|
278 |
+
if not is_dist_avail_and_initialized():
|
279 |
+
return
|
280 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
281 |
+
dist.barrier()
|
282 |
+
dist.all_reduce(t)
|
283 |
+
t = t.tolist()
|
284 |
+
self.count = int(t[0])
|
285 |
+
self.total = t[1]
|
286 |
+
|
287 |
+
@property
|
288 |
+
def median(self):
|
289 |
+
d = torch.tensor(list(self.deque))
|
290 |
+
return d.median().item()
|
291 |
+
|
292 |
+
@property
|
293 |
+
def avg(self):
|
294 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
295 |
+
return d.mean().item()
|
296 |
+
|
297 |
+
@property
|
298 |
+
def global_avg(self):
|
299 |
+
return self.total / self.count
|
300 |
+
|
301 |
+
@property
|
302 |
+
def max(self):
|
303 |
+
return max(self.deque)
|
304 |
+
|
305 |
+
@property
|
306 |
+
def value(self):
|
307 |
+
return self.deque[-1]
|
308 |
+
|
309 |
+
def __str__(self):
|
310 |
+
return self.fmt.format(
|
311 |
+
median=self.median,
|
312 |
+
avg=self.avg,
|
313 |
+
global_avg=self.global_avg,
|
314 |
+
max=self.max,
|
315 |
+
value=self.value,
|
316 |
+
)
|
317 |
+
|
318 |
+
|
319 |
+
def reduce_dict(input_dict, average=True):
|
320 |
+
"""
|
321 |
+
Args:
|
322 |
+
input_dict (dict): all the values will be reduced
|
323 |
+
average (bool): whether to do average or sum
|
324 |
+
Reduce the values in the dictionary from all processes so that all processes
|
325 |
+
have the averaged results. Returns a dict with the same fields as
|
326 |
+
input_dict, after reduction.
|
327 |
+
"""
|
328 |
+
world_size = get_world_size()
|
329 |
+
if world_size < 2:
|
330 |
+
return input_dict
|
331 |
+
with torch.no_grad():
|
332 |
+
names = []
|
333 |
+
values = []
|
334 |
+
# sort the keys so that they are consistent across processes
|
335 |
+
for k in sorted(input_dict.keys()):
|
336 |
+
names.append(k)
|
337 |
+
values.append(input_dict[k])
|
338 |
+
values = torch.stack(values, dim=0)
|
339 |
+
dist.all_reduce(values)
|
340 |
+
if average:
|
341 |
+
values /= world_size
|
342 |
+
reduced_dict = {k: v for k, v in zip(names, values)}
|
343 |
+
return reduced_dict
|
344 |
+
|
345 |
+
|
346 |
+
class MetricLogger(object):
|
347 |
+
def __init__(self, delimiter="\t"):
|
348 |
+
self.meters = defaultdict(SmoothedValue)
|
349 |
+
self.delimiter = delimiter
|
350 |
+
|
351 |
+
def update(self, **kwargs):
|
352 |
+
for k, v in kwargs.items():
|
353 |
+
if isinstance(v, torch.Tensor):
|
354 |
+
v = v.item()
|
355 |
+
assert isinstance(v, (float, int))
|
356 |
+
self.meters[k].update(v)
|
357 |
+
|
358 |
+
def __getattr__(self, attr):
|
359 |
+
if attr in self.meters:
|
360 |
+
return self.meters[attr]
|
361 |
+
if attr in self.__dict__:
|
362 |
+
return self.__dict__[attr]
|
363 |
+
raise AttributeError(
|
364 |
+
"'{}' object has no attribute '{}'".format(type(self).__name__, attr)
|
365 |
+
)
|
366 |
+
|
367 |
+
def __str__(self):
|
368 |
+
loss_str = []
|
369 |
+
for name, meter in self.meters.items():
|
370 |
+
loss_str.append("{}: {}".format(name, str(meter)))
|
371 |
+
return self.delimiter.join(loss_str)
|
372 |
+
|
373 |
+
def synchronize_between_processes(self):
|
374 |
+
for meter in self.meters.values():
|
375 |
+
meter.synchronize_between_processes()
|
376 |
+
|
377 |
+
def add_meter(self, name, meter):
|
378 |
+
self.meters[name] = meter
|
379 |
+
|
380 |
+
def log_every(self, iterable, print_freq, header=None):
|
381 |
+
i = 0
|
382 |
+
if not header:
|
383 |
+
header = ""
|
384 |
+
start_time = time.time()
|
385 |
+
end = time.time()
|
386 |
+
iter_time = SmoothedValue(fmt="{avg:.6f}")
|
387 |
+
data_time = SmoothedValue(fmt="{avg:.6f}")
|
388 |
+
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
389 |
+
if torch.cuda.is_available():
|
390 |
+
log_msg = self.delimiter.join(
|
391 |
+
[
|
392 |
+
header,
|
393 |
+
"[{0" + space_fmt + "}/{1}]",
|
394 |
+
"eta: {eta}",
|
395 |
+
"{meters}",
|
396 |
+
"time: {time}",
|
397 |
+
"data: {data}",
|
398 |
+
"max mem: {memory:.0f}",
|
399 |
+
]
|
400 |
+
)
|
401 |
+
else:
|
402 |
+
log_msg = self.delimiter.join(
|
403 |
+
[
|
404 |
+
header,
|
405 |
+
"[{0" + space_fmt + "}/{1}]",
|
406 |
+
"eta: {eta}",
|
407 |
+
"{meters}",
|
408 |
+
"time: {time}",
|
409 |
+
"data: {data}",
|
410 |
+
]
|
411 |
+
)
|
412 |
+
MB = 1024.0 * 1024.0
|
413 |
+
for obj in iterable:
|
414 |
+
data_time.update(time.time() - end)
|
415 |
+
yield obj
|
416 |
+
iter_time.update(time.time() - end)
|
417 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
418 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
419 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
420 |
+
if torch.cuda.is_available():
|
421 |
+
print(
|
422 |
+
log_msg.format(
|
423 |
+
i,
|
424 |
+
len(iterable),
|
425 |
+
eta=eta_string,
|
426 |
+
meters=str(self),
|
427 |
+
time=str(iter_time),
|
428 |
+
data=str(data_time),
|
429 |
+
memory=torch.cuda.max_memory_allocated() / MB,
|
430 |
+
)
|
431 |
+
)
|
432 |
+
else:
|
433 |
+
print(
|
434 |
+
log_msg.format(
|
435 |
+
i,
|
436 |
+
len(iterable),
|
437 |
+
eta=eta_string,
|
438 |
+
meters=str(self),
|
439 |
+
time=str(iter_time),
|
440 |
+
data=str(data_time),
|
441 |
+
)
|
442 |
+
)
|
443 |
+
i += 1
|
444 |
+
end = time.time()
|
445 |
+
total_time = time.time() - start_time
|
446 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
447 |
+
print(
|
448 |
+
"{} Total time: {} ({:.6f} s / it)".format(
|
449 |
+
header, total_time_str, total_time / len(iterable)
|
450 |
+
)
|
451 |
+
)
|
452 |
+
|
453 |
+
|
454 |
+
def get_sha():
|
455 |
+
cwd = os.path.dirname(os.path.abspath(__file__))
|
456 |
+
|
457 |
+
def _run(command):
|
458 |
+
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
|
459 |
+
|
460 |
+
sha = "N/A"
|
461 |
+
diff = "clean"
|
462 |
+
branch = "N/A"
|
463 |
+
try:
|
464 |
+
sha = _run(["git", "rev-parse", "HEAD"])
|
465 |
+
subprocess.check_output(["git", "diff"], cwd=cwd)
|
466 |
+
diff = _run(["git", "diff-index", "HEAD"])
|
467 |
+
diff = "has uncommited changes" if diff else "clean"
|
468 |
+
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
469 |
+
except Exception:
|
470 |
+
pass
|
471 |
+
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
472 |
+
return message
|
473 |
+
|
474 |
+
|
475 |
+
def is_dist_avail_and_initialized():
|
476 |
+
if not dist.is_available():
|
477 |
+
return False
|
478 |
+
if not dist.is_initialized():
|
479 |
+
return False
|
480 |
+
return True
|
481 |
+
|
482 |
+
|
483 |
+
def get_world_size():
|
484 |
+
if not is_dist_avail_and_initialized():
|
485 |
+
return 1
|
486 |
+
return dist.get_world_size()
|
487 |
+
|
488 |
+
|
489 |
+
def get_rank():
|
490 |
+
if not is_dist_avail_and_initialized():
|
491 |
+
return 0
|
492 |
+
return dist.get_rank()
|
493 |
+
|
494 |
+
|
495 |
+
def is_main_process():
|
496 |
+
return get_rank() == 0
|
497 |
+
|
498 |
+
|
499 |
+
def save_on_master(*args, **kwargs):
|
500 |
+
if is_main_process():
|
501 |
+
torch.save(*args, **kwargs)
|
502 |
+
|
503 |
+
|
504 |
+
def setup_for_distributed(is_master):
|
505 |
+
"""
|
506 |
+
This function disables printing when not in master process
|
507 |
+
"""
|
508 |
+
import builtins as __builtin__
|
509 |
+
|
510 |
+
builtin_print = __builtin__.print
|
511 |
+
|
512 |
+
def print(*args, **kwargs):
|
513 |
+
force = kwargs.pop("force", False)
|
514 |
+
if is_master or force:
|
515 |
+
builtin_print(*args, **kwargs)
|
516 |
+
|
517 |
+
__builtin__.print = print
|
518 |
+
|
519 |
+
|
520 |
+
def init_distributed_mode(args):
|
521 |
+
# launched with torch.distributed.launch
|
522 |
+
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
523 |
+
args.rank = int(os.environ["RANK"])
|
524 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
525 |
+
args.gpu = int(os.environ["LOCAL_RANK"])
|
526 |
+
# launched with submitit on a slurm cluster
|
527 |
+
elif "SLURM_PROCID" in os.environ:
|
528 |
+
args.rank = int(os.environ["SLURM_PROCID"])
|
529 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
530 |
+
# launched naively with `python main_dino.py`
|
531 |
+
# we manually add MASTER_ADDR and MASTER_PORT to env variables
|
532 |
+
elif torch.cuda.is_available():
|
533 |
+
print("Will run the code on one GPU.")
|
534 |
+
args.rank, args.gpu, args.world_size = 0, 0, 1
|
535 |
+
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
536 |
+
os.environ["MASTER_PORT"] = "29500"
|
537 |
+
else:
|
538 |
+
print("Does not support training without GPU.")
|
539 |
+
sys.exit(1)
|
540 |
+
|
541 |
+
dist.init_process_group(
|
542 |
+
backend="nccl",
|
543 |
+
init_method=args.dist_url,
|
544 |
+
world_size=args.world_size,
|
545 |
+
rank=args.rank,
|
546 |
+
)
|
547 |
+
|
548 |
+
torch.cuda.set_device(args.gpu)
|
549 |
+
print(
|
550 |
+
"| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True
|
551 |
+
)
|
552 |
+
dist.barrier()
|
553 |
+
setup_for_distributed(args.rank == 0)
|
554 |
+
|
555 |
+
|
556 |
+
def accuracy(output, target, topk=(1,)):
|
557 |
+
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
558 |
+
maxk = max(topk)
|
559 |
+
batch_size = target.size(0)
|
560 |
+
_, pred = output.topk(maxk, 1, True, True)
|
561 |
+
pred = pred.t()
|
562 |
+
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
|
563 |
+
return [correct[:k].reshape(-1).float().sum(0) * 100.0 / batch_size for k in topk]
|
564 |
+
|
565 |
+
|
566 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
567 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
568 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
569 |
+
def norm_cdf(x):
|
570 |
+
# Computes standard normal cumulative distribution function
|
571 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
572 |
+
|
573 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
574 |
+
warnings.warn(
|
575 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
576 |
+
"The distribution of values may be incorrect.",
|
577 |
+
stacklevel=2,
|
578 |
+
)
|
579 |
+
|
580 |
+
with torch.no_grad():
|
581 |
+
# Values are generated by using a truncated uniform distribution and
|
582 |
+
# then using the inverse CDF for the normal distribution.
|
583 |
+
# Get upper and lower cdf values
|
584 |
+
l = norm_cdf((a - mean) / std)
|
585 |
+
u = norm_cdf((b - mean) / std)
|
586 |
+
|
587 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
588 |
+
# [2l-1, 2u-1].
|
589 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
590 |
+
|
591 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
592 |
+
# standard normal
|
593 |
+
tensor.erfinv_()
|
594 |
+
|
595 |
+
# Transform to proper mean, std
|
596 |
+
tensor.mul_(std * math.sqrt(2.0))
|
597 |
+
tensor.add_(mean)
|
598 |
+
|
599 |
+
# Clamp to ensure it's in the proper range
|
600 |
+
tensor.clamp_(min=a, max=b)
|
601 |
+
return tensor
|
602 |
+
|
603 |
+
|
604 |
+
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
605 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
606 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
607 |
+
|
608 |
+
|
609 |
+
class LARS(torch.optim.Optimizer):
|
610 |
+
"""
|
611 |
+
Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
|
612 |
+
"""
|
613 |
+
|
614 |
+
def __init__(
|
615 |
+
self,
|
616 |
+
params,
|
617 |
+
lr=0,
|
618 |
+
weight_decay=0,
|
619 |
+
momentum=0.9,
|
620 |
+
eta=0.001,
|
621 |
+
weight_decay_filter=None,
|
622 |
+
lars_adaptation_filter=None,
|
623 |
+
):
|
624 |
+
defaults = dict(
|
625 |
+
lr=lr,
|
626 |
+
weight_decay=weight_decay,
|
627 |
+
momentum=momentum,
|
628 |
+
eta=eta,
|
629 |
+
weight_decay_filter=weight_decay_filter,
|
630 |
+
lars_adaptation_filter=lars_adaptation_filter,
|
631 |
+
)
|
632 |
+
super().__init__(params, defaults)
|
633 |
+
|
634 |
+
@torch.no_grad()
|
635 |
+
def step(self):
|
636 |
+
for g in self.param_groups:
|
637 |
+
for p in g["params"]:
|
638 |
+
dp = p.grad
|
639 |
+
|
640 |
+
if dp is None:
|
641 |
+
continue
|
642 |
+
|
643 |
+
if p.ndim != 1:
|
644 |
+
dp = dp.add(p, alpha=g["weight_decay"])
|
645 |
+
|
646 |
+
if p.ndim != 1:
|
647 |
+
param_norm = torch.norm(p)
|
648 |
+
update_norm = torch.norm(dp)
|
649 |
+
one = torch.ones_like(param_norm)
|
650 |
+
q = torch.where(
|
651 |
+
param_norm > 0.0,
|
652 |
+
torch.where(
|
653 |
+
update_norm > 0, (g["eta"] * param_norm / update_norm), one
|
654 |
+
),
|
655 |
+
one,
|
656 |
+
)
|
657 |
+
dp = dp.mul(q)
|
658 |
+
|
659 |
+
param_state = self.state[p]
|
660 |
+
if "mu" not in param_state:
|
661 |
+
param_state["mu"] = torch.zeros_like(p)
|
662 |
+
mu = param_state["mu"]
|
663 |
+
mu.mul_(g["momentum"]).add_(dp)
|
664 |
+
|
665 |
+
p.add_(mu, alpha=-g["lr"])
|
666 |
+
|
667 |
+
|
668 |
+
class MultiCropWrapper(nn.Module):
|
669 |
+
"""
|
670 |
+
Perform forward pass separately on each resolution input.
|
671 |
+
The inputs corresponding to a single resolution are clubbed and single
|
672 |
+
forward is run on the same resolution inputs. Hence we do several
|
673 |
+
forward passes = number of different resolutions used. We then
|
674 |
+
concatenate all the output features and run the head forward on these
|
675 |
+
concatenated features.
|
676 |
+
"""
|
677 |
+
|
678 |
+
def __init__(self, backbone, head):
|
679 |
+
super(MultiCropWrapper, self).__init__()
|
680 |
+
# disable layers dedicated to ImageNet labels classification
|
681 |
+
backbone.fc, backbone.head = nn.Identity(), nn.Identity()
|
682 |
+
self.backbone = backbone
|
683 |
+
self.head = head
|
684 |
+
|
685 |
+
def forward(self, x):
|
686 |
+
# convert to list
|
687 |
+
if not isinstance(x, list):
|
688 |
+
x = [x]
|
689 |
+
idx_crops = torch.cumsum(
|
690 |
+
torch.unique_consecutive(
|
691 |
+
torch.tensor([inp.shape[-1] for inp in x]),
|
692 |
+
return_counts=True,
|
693 |
+
)[1],
|
694 |
+
0,
|
695 |
+
)
|
696 |
+
start_idx, output = 0, torch.empty(0).to(x[0].device)
|
697 |
+
for end_idx in idx_crops:
|
698 |
+
_out = self.backbone(torch.cat(x[start_idx:end_idx]))
|
699 |
+
# The output is a tuple with XCiT model. See:
|
700 |
+
# https://github.com/facebookresearch/xcit/blob/master/xcit.py#L404-L405
|
701 |
+
if isinstance(_out, tuple):
|
702 |
+
_out = _out[0]
|
703 |
+
# accumulate outputs
|
704 |
+
output = torch.cat((output, _out))
|
705 |
+
start_idx = end_idx
|
706 |
+
# Run the head forward on the concatenated features.
|
707 |
+
return self.head(output)
|
708 |
+
|
709 |
+
|
710 |
+
def get_params_groups(model):
|
711 |
+
regularized = []
|
712 |
+
not_regularized = []
|
713 |
+
for name, param in model.named_parameters():
|
714 |
+
if not param.requires_grad:
|
715 |
+
continue
|
716 |
+
# we do not regularize biases nor Norm parameters
|
717 |
+
if name.endswith(".bias") or len(param.shape) == 1:
|
718 |
+
not_regularized.append(param)
|
719 |
+
else:
|
720 |
+
regularized.append(param)
|
721 |
+
return [{"params": regularized}, {"params": not_regularized, "weight_decay": 0.0}]
|
722 |
+
|
723 |
+
|
724 |
+
def has_batchnorms(model):
|
725 |
+
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
|
726 |
+
for name, module in model.named_modules():
|
727 |
+
if isinstance(module, bn_types):
|
728 |
+
return True
|
729 |
+
return False
|
730 |
+
|
731 |
+
|
732 |
+
class PCA:
|
733 |
+
"""
|
734 |
+
Class to compute and apply PCA.
|
735 |
+
"""
|
736 |
+
|
737 |
+
def __init__(self, dim=256, whit=0.5):
|
738 |
+
self.dim = dim
|
739 |
+
self.whit = whit
|
740 |
+
self.mean = None
|
741 |
+
|
742 |
+
def train_pca(self, cov):
|
743 |
+
"""
|
744 |
+
Takes a covariance matrix (np.ndarray) as input.
|
745 |
+
"""
|
746 |
+
d, v = np.linalg.eigh(cov)
|
747 |
+
eps = d.max() * 1e-5
|
748 |
+
n_0 = (d < eps).sum()
|
749 |
+
if n_0 > 0:
|
750 |
+
d[d < eps] = eps
|
751 |
+
|
752 |
+
# total energy
|
753 |
+
totenergy = d.sum()
|
754 |
+
|
755 |
+
# sort eigenvectors with eigenvalues order
|
756 |
+
idx = np.argsort(d)[::-1][: self.dim]
|
757 |
+
d = d[idx]
|
758 |
+
v = v[:, idx]
|
759 |
+
|
760 |
+
print("keeping %.2f %% of the energy" % (d.sum() / totenergy * 100.0))
|
761 |
+
|
762 |
+
# for the whitening
|
763 |
+
d = np.diag(1.0 / d**self.whit)
|
764 |
+
|
765 |
+
# principal components
|
766 |
+
self.dvt = np.dot(d, v.T)
|
767 |
+
|
768 |
+
def apply(self, x):
|
769 |
+
# input is from numpy
|
770 |
+
if isinstance(x, np.ndarray):
|
771 |
+
if self.mean is not None:
|
772 |
+
x -= self.mean
|
773 |
+
return np.dot(self.dvt, x.T).T
|
774 |
+
|
775 |
+
# input is from torch and is on GPU
|
776 |
+
if x.is_cuda:
|
777 |
+
if self.mean is not None:
|
778 |
+
x -= torch.cuda.FloatTensor(self.mean)
|
779 |
+
return torch.mm(
|
780 |
+
torch.cuda.FloatTensor(self.dvt), x.transpose(0, 1)
|
781 |
+
).transpose(0, 1)
|
782 |
+
|
783 |
+
# input if from torch, on CPU
|
784 |
+
if self.mean is not None:
|
785 |
+
x -= torch.FloatTensor(self.mean)
|
786 |
+
return torch.mm(torch.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)
|
787 |
+
|
788 |
+
|
789 |
+
def compute_ap(ranks, nres):
|
790 |
+
"""
|
791 |
+
Computes average precision for given ranked indexes.
|
792 |
+
Arguments
|
793 |
+
---------
|
794 |
+
ranks : zerro-based ranks of positive images
|
795 |
+
nres : number of positive images
|
796 |
+
Returns
|
797 |
+
-------
|
798 |
+
ap : average precision
|
799 |
+
"""
|
800 |
+
|
801 |
+
# number of images ranked by the system
|
802 |
+
nimgranks = len(ranks)
|
803 |
+
|
804 |
+
# accumulate trapezoids in PR-plot
|
805 |
+
ap = 0
|
806 |
+
|
807 |
+
recall_step = 1.0 / nres
|
808 |
+
|
809 |
+
for j in np.arange(nimgranks):
|
810 |
+
rank = ranks[j]
|
811 |
+
|
812 |
+
if rank == 0:
|
813 |
+
precision_0 = 1.0
|
814 |
+
else:
|
815 |
+
precision_0 = float(j) / rank
|
816 |
+
|
817 |
+
precision_1 = float(j + 1) / (rank + 1)
|
818 |
+
|
819 |
+
ap += (precision_0 + precision_1) * recall_step / 2.0
|
820 |
+
|
821 |
+
return ap
|
822 |
+
|
823 |
+
|
824 |
+
def compute_map(ranks, gnd, kappas=[]):
|
825 |
+
"""
|
826 |
+
Computes the mAP for a given set of returned results.
|
827 |
+
Usage:
|
828 |
+
map = compute_map (ranks, gnd)
|
829 |
+
computes mean average precsion (map) only
|
830 |
+
map, aps, pr, prs = compute_map (ranks, gnd, kappas)
|
831 |
+
computes mean average precision (map), average precision (aps) for each query
|
832 |
+
computes mean precision at kappas (pr), precision at kappas (prs) for each query
|
833 |
+
Notes:
|
834 |
+
1) ranks starts from 0, ranks.shape = db_size X #queries
|
835 |
+
2) The junk results (e.g., the query itself) should be declared in the gnd stuct array
|
836 |
+
3) If there are no positive images for some query, that query is excluded from the evaluation
|
837 |
+
"""
|
838 |
+
|
839 |
+
map = 0.0
|
840 |
+
nq = len(gnd) # number of queries
|
841 |
+
aps = np.zeros(nq)
|
842 |
+
pr = np.zeros(len(kappas))
|
843 |
+
prs = np.zeros((nq, len(kappas)))
|
844 |
+
nempty = 0
|
845 |
+
|
846 |
+
for i in np.arange(nq):
|
847 |
+
qgnd = np.array(gnd[i]["ok"])
|
848 |
+
|
849 |
+
# no positive images, skip from the average
|
850 |
+
if qgnd.shape[0] == 0:
|
851 |
+
aps[i] = float("nan")
|
852 |
+
prs[i, :] = float("nan")
|
853 |
+
nempty += 1
|
854 |
+
continue
|
855 |
+
|
856 |
+
try:
|
857 |
+
qgndj = np.array(gnd[i]["junk"])
|
858 |
+
except:
|
859 |
+
qgndj = np.empty(0)
|
860 |
+
|
861 |
+
# sorted positions of positive and junk images (0 based)
|
862 |
+
pos = np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgnd)]
|
863 |
+
junk = np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgndj)]
|
864 |
+
|
865 |
+
k = 0
|
866 |
+
ij = 0
|
867 |
+
if len(junk):
|
868 |
+
# decrease positions of positives based on the number of
|
869 |
+
# junk images appearing before them
|
870 |
+
ip = 0
|
871 |
+
while ip < len(pos):
|
872 |
+
while ij < len(junk) and pos[ip] > junk[ij]:
|
873 |
+
k += 1
|
874 |
+
ij += 1
|
875 |
+
pos[ip] = pos[ip] - k
|
876 |
+
ip += 1
|
877 |
+
|
878 |
+
# compute ap
|
879 |
+
ap = compute_ap(pos, len(qgnd))
|
880 |
+
map = map + ap
|
881 |
+
aps[i] = ap
|
882 |
+
|
883 |
+
# compute precision @ k
|
884 |
+
pos += 1 # get it to 1-based
|
885 |
+
for j in np.arange(len(kappas)):
|
886 |
+
kq = min(max(pos), kappas[j])
|
887 |
+
prs[i, j] = (pos <= kq).sum() / kq
|
888 |
+
pr = pr + prs[i, :]
|
889 |
+
|
890 |
+
map = map / (nq - nempty)
|
891 |
+
pr = pr / (nq - nempty)
|
892 |
+
|
893 |
+
return map, aps, pr, prs
|
894 |
+
|
895 |
+
|
896 |
+
def multi_scale(samples, model):
|
897 |
+
v = None
|
898 |
+
for s in [1, 1 / 2 ** (1 / 2), 1 / 2]: # we use 3 different scales
|
899 |
+
if s == 1:
|
900 |
+
inp = samples.clone()
|
901 |
+
else:
|
902 |
+
inp = nn.functional.interpolate(
|
903 |
+
samples, scale_factor=s, mode="bilinear", align_corners=False
|
904 |
+
)
|
905 |
+
feats = model(inp).clone()
|
906 |
+
if v is None:
|
907 |
+
v = feats
|
908 |
+
else:
|
909 |
+
v += feats
|
910 |
+
v /= 3
|
911 |
+
v /= v.norm()
|
912 |
+
return v
|
dino/video_generation.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import glob
|
16 |
+
import sys
|
17 |
+
import argparse
|
18 |
+
import cv2
|
19 |
+
|
20 |
+
from tqdm import tqdm
|
21 |
+
import matplotlib.pyplot as plt
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
import torchvision
|
25 |
+
from torchvision import transforms as pth_transforms
|
26 |
+
import numpy as np
|
27 |
+
from PIL import Image
|
28 |
+
|
29 |
+
import utils
|
30 |
+
import vision_transformer as vits
|
31 |
+
|
32 |
+
|
33 |
+
FOURCC = {
|
34 |
+
"mp4": cv2.VideoWriter_fourcc(*"MP4V"),
|
35 |
+
"avi": cv2.VideoWriter_fourcc(*"XVID"),
|
36 |
+
}
|
37 |
+
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
38 |
+
|
39 |
+
|
40 |
+
class VideoGenerator:
|
41 |
+
def __init__(self, args):
|
42 |
+
self.args = args
|
43 |
+
# self.model = None
|
44 |
+
# Don't need to load model if you only want a video
|
45 |
+
if not self.args.video_only:
|
46 |
+
self.model = self.__load_model()
|
47 |
+
|
48 |
+
def run(self):
|
49 |
+
if self.args.input_path is None:
|
50 |
+
print(f"Provided input path {self.args.input_path} is non valid.")
|
51 |
+
sys.exit(1)
|
52 |
+
else:
|
53 |
+
if self.args.video_only:
|
54 |
+
self._generate_video_from_images(
|
55 |
+
self.args.input_path, self.args.output_path
|
56 |
+
)
|
57 |
+
else:
|
58 |
+
# If input path exists
|
59 |
+
if os.path.exists(self.args.input_path):
|
60 |
+
# If input is a video file
|
61 |
+
if os.path.isfile(self.args.input_path):
|
62 |
+
frames_folder = os.path.join(self.args.output_path, "frames")
|
63 |
+
attention_folder = os.path.join(
|
64 |
+
self.args.output_path, "attention"
|
65 |
+
)
|
66 |
+
|
67 |
+
os.makedirs(frames_folder, exist_ok=True)
|
68 |
+
os.makedirs(attention_folder, exist_ok=True)
|
69 |
+
|
70 |
+
self._extract_frames_from_video(
|
71 |
+
self.args.input_path, frames_folder
|
72 |
+
)
|
73 |
+
|
74 |
+
self._inference(
|
75 |
+
frames_folder,
|
76 |
+
attention_folder,
|
77 |
+
)
|
78 |
+
|
79 |
+
self._generate_video_from_images(
|
80 |
+
attention_folder, self.args.output_path
|
81 |
+
)
|
82 |
+
|
83 |
+
# If input is a folder of already extracted frames
|
84 |
+
if os.path.isdir(self.args.input_path):
|
85 |
+
attention_folder = os.path.join(
|
86 |
+
self.args.output_path, "attention"
|
87 |
+
)
|
88 |
+
|
89 |
+
os.makedirs(attention_folder, exist_ok=True)
|
90 |
+
|
91 |
+
self._inference(self.args.input_path, attention_folder)
|
92 |
+
|
93 |
+
self._generate_video_from_images(
|
94 |
+
attention_folder, self.args.output_path
|
95 |
+
)
|
96 |
+
|
97 |
+
# If input path doesn't exists
|
98 |
+
else:
|
99 |
+
print(f"Provided input path {self.args.input_path} doesn't exists.")
|
100 |
+
sys.exit(1)
|
101 |
+
|
102 |
+
def _extract_frames_from_video(self, inp: str, out: str):
|
103 |
+
vidcap = cv2.VideoCapture(inp)
|
104 |
+
self.args.fps = vidcap.get(cv2.CAP_PROP_FPS)
|
105 |
+
|
106 |
+
print(f"Video: {inp} ({self.args.fps} fps)")
|
107 |
+
print(f"Extracting frames to {out}")
|
108 |
+
|
109 |
+
success, image = vidcap.read()
|
110 |
+
count = 0
|
111 |
+
while success:
|
112 |
+
cv2.imwrite(
|
113 |
+
os.path.join(out, f"frame-{count:04}.jpg"),
|
114 |
+
image,
|
115 |
+
)
|
116 |
+
success, image = vidcap.read()
|
117 |
+
count += 1
|
118 |
+
|
119 |
+
def _generate_video_from_images(self, inp: str, out: str):
|
120 |
+
img_array = []
|
121 |
+
attention_images_list = sorted(glob.glob(os.path.join(inp, "attn-*.jpg")))
|
122 |
+
|
123 |
+
# Get size of the first image
|
124 |
+
with open(attention_images_list[0], "rb") as f:
|
125 |
+
img = Image.open(f)
|
126 |
+
img = img.convert("RGB")
|
127 |
+
size = (img.width, img.height)
|
128 |
+
img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
|
129 |
+
|
130 |
+
print(f"Generating video {size} to {out}")
|
131 |
+
|
132 |
+
for filename in tqdm(attention_images_list[1:]):
|
133 |
+
with open(filename, "rb") as f:
|
134 |
+
img = Image.open(f)
|
135 |
+
img = img.convert("RGB")
|
136 |
+
img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
|
137 |
+
|
138 |
+
out = cv2.VideoWriter(
|
139 |
+
os.path.join(out, "video." + self.args.video_format),
|
140 |
+
FOURCC[self.args.video_format],
|
141 |
+
self.args.fps,
|
142 |
+
size,
|
143 |
+
)
|
144 |
+
|
145 |
+
for i in range(len(img_array)):
|
146 |
+
out.write(img_array[i])
|
147 |
+
out.release()
|
148 |
+
print("Done")
|
149 |
+
|
150 |
+
def _inference(self, inp: str, out: str):
|
151 |
+
print(f"Generating attention images to {out}")
|
152 |
+
|
153 |
+
for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))):
|
154 |
+
with open(img_path, "rb") as f:
|
155 |
+
img = Image.open(f)
|
156 |
+
img = img.convert("RGB")
|
157 |
+
|
158 |
+
if self.args.resize is not None:
|
159 |
+
transform = pth_transforms.Compose(
|
160 |
+
[
|
161 |
+
pth_transforms.ToTensor(),
|
162 |
+
pth_transforms.Resize(self.args.resize),
|
163 |
+
pth_transforms.Normalize(
|
164 |
+
(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
|
165 |
+
),
|
166 |
+
]
|
167 |
+
)
|
168 |
+
else:
|
169 |
+
transform = pth_transforms.Compose(
|
170 |
+
[
|
171 |
+
pth_transforms.ToTensor(),
|
172 |
+
pth_transforms.Normalize(
|
173 |
+
(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
|
174 |
+
),
|
175 |
+
]
|
176 |
+
)
|
177 |
+
|
178 |
+
img = transform(img)
|
179 |
+
|
180 |
+
# make the image divisible by the patch size
|
181 |
+
w, h = (
|
182 |
+
img.shape[1] - img.shape[1] % self.args.patch_size,
|
183 |
+
img.shape[2] - img.shape[2] % self.args.patch_size,
|
184 |
+
)
|
185 |
+
img = img[:, :w, :h].unsqueeze(0)
|
186 |
+
|
187 |
+
w_featmap = img.shape[-2] // self.args.patch_size
|
188 |
+
h_featmap = img.shape[-1] // self.args.patch_size
|
189 |
+
|
190 |
+
attentions = self.model.get_last_selfattention(img.to(DEVICE))
|
191 |
+
|
192 |
+
nh = attentions.shape[1] # number of head
|
193 |
+
|
194 |
+
# we keep only the output patch attention
|
195 |
+
attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
|
196 |
+
|
197 |
+
# we keep only a certain percentage of the mass
|
198 |
+
val, idx = torch.sort(attentions)
|
199 |
+
val /= torch.sum(val, dim=1, keepdim=True)
|
200 |
+
cumval = torch.cumsum(val, dim=1)
|
201 |
+
th_attn = cumval > (1 - self.args.threshold)
|
202 |
+
idx2 = torch.argsort(idx)
|
203 |
+
for head in range(nh):
|
204 |
+
th_attn[head] = th_attn[head][idx2[head]]
|
205 |
+
th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
|
206 |
+
# interpolate
|
207 |
+
th_attn = (
|
208 |
+
nn.functional.interpolate(
|
209 |
+
th_attn.unsqueeze(0),
|
210 |
+
scale_factor=self.args.patch_size,
|
211 |
+
mode="nearest",
|
212 |
+
)[0]
|
213 |
+
.cpu()
|
214 |
+
.numpy()
|
215 |
+
)
|
216 |
+
|
217 |
+
attentions = attentions.reshape(nh, w_featmap, h_featmap)
|
218 |
+
attentions = (
|
219 |
+
nn.functional.interpolate(
|
220 |
+
attentions.unsqueeze(0),
|
221 |
+
scale_factor=self.args.patch_size,
|
222 |
+
mode="nearest",
|
223 |
+
)[0]
|
224 |
+
.cpu()
|
225 |
+
.numpy()
|
226 |
+
)
|
227 |
+
|
228 |
+
# save attentions heatmaps
|
229 |
+
fname = os.path.join(out, "attn-" + os.path.basename(img_path))
|
230 |
+
plt.imsave(
|
231 |
+
fname=fname,
|
232 |
+
arr=sum(
|
233 |
+
attentions[i] * 1 / attentions.shape[0]
|
234 |
+
for i in range(attentions.shape[0])
|
235 |
+
),
|
236 |
+
cmap="inferno",
|
237 |
+
format="jpg",
|
238 |
+
)
|
239 |
+
|
240 |
+
def __load_model(self):
|
241 |
+
# build model
|
242 |
+
model = vits.__dict__[self.args.arch](
|
243 |
+
patch_size=self.args.patch_size, num_classes=0
|
244 |
+
)
|
245 |
+
for p in model.parameters():
|
246 |
+
p.requires_grad = False
|
247 |
+
model.eval()
|
248 |
+
model.to(DEVICE)
|
249 |
+
|
250 |
+
if os.path.isfile(self.args.pretrained_weights):
|
251 |
+
state_dict = torch.load(self.args.pretrained_weights, map_location="cpu")
|
252 |
+
if (
|
253 |
+
self.args.checkpoint_key is not None
|
254 |
+
and self.args.checkpoint_key in state_dict
|
255 |
+
):
|
256 |
+
print(
|
257 |
+
f"Take key {self.args.checkpoint_key} in provided checkpoint dict"
|
258 |
+
)
|
259 |
+
state_dict = state_dict[self.args.checkpoint_key]
|
260 |
+
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
261 |
+
# remove `backbone.` prefix induced by multicrop wrapper
|
262 |
+
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
|
263 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
264 |
+
print(
|
265 |
+
"Pretrained weights found at {} and loaded with msg: {}".format(
|
266 |
+
self.args.pretrained_weights, msg
|
267 |
+
)
|
268 |
+
)
|
269 |
+
else:
|
270 |
+
print(
|
271 |
+
"Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
|
272 |
+
)
|
273 |
+
url = None
|
274 |
+
if self.args.arch == "vit_small" and self.args.patch_size == 16:
|
275 |
+
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
|
276 |
+
elif self.args.arch == "vit_small" and self.args.patch_size == 8:
|
277 |
+
url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper
|
278 |
+
elif self.args.arch == "vit_base" and self.args.patch_size == 16:
|
279 |
+
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
|
280 |
+
elif self.args.arch == "vit_base" and self.args.patch_size == 8:
|
281 |
+
url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
|
282 |
+
if url is not None:
|
283 |
+
print(
|
284 |
+
"Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
|
285 |
+
)
|
286 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
287 |
+
url="https://dl.fbaipublicfiles.com/dino/" + url
|
288 |
+
)
|
289 |
+
model.load_state_dict(state_dict, strict=True)
|
290 |
+
else:
|
291 |
+
print(
|
292 |
+
"There is no reference weights available for this model => We use random weights."
|
293 |
+
)
|
294 |
+
return model
|
295 |
+
|
296 |
+
|
297 |
+
def parse_args():
|
298 |
+
parser = argparse.ArgumentParser("Generation self-attention video")
|
299 |
+
parser.add_argument(
|
300 |
+
"--arch",
|
301 |
+
default="vit_small",
|
302 |
+
type=str,
|
303 |
+
choices=["vit_tiny", "vit_small", "vit_base"],
|
304 |
+
help="Architecture (support only ViT atm).",
|
305 |
+
)
|
306 |
+
parser.add_argument(
|
307 |
+
"--patch_size", default=8, type=int, help="Patch resolution of the self.model."
|
308 |
+
)
|
309 |
+
parser.add_argument(
|
310 |
+
"--pretrained_weights",
|
311 |
+
default="",
|
312 |
+
type=str,
|
313 |
+
help="Path to pretrained weights to load.",
|
314 |
+
)
|
315 |
+
parser.add_argument(
|
316 |
+
"--checkpoint_key",
|
317 |
+
default="teacher",
|
318 |
+
type=str,
|
319 |
+
help='Key to use in the checkpoint (example: "teacher")',
|
320 |
+
)
|
321 |
+
parser.add_argument(
|
322 |
+
"--input_path",
|
323 |
+
required=True,
|
324 |
+
type=str,
|
325 |
+
help="""Path to a video file if you want to extract frames
|
326 |
+
or to a folder of images already extracted by yourself.
|
327 |
+
or to a folder of attention images.""",
|
328 |
+
)
|
329 |
+
parser.add_argument(
|
330 |
+
"--output_path",
|
331 |
+
default="./",
|
332 |
+
type=str,
|
333 |
+
help="""Path to store a folder of frames and / or a folder of attention images.
|
334 |
+
and / or a final video. Default to current directory.""",
|
335 |
+
)
|
336 |
+
parser.add_argument(
|
337 |
+
"--threshold",
|
338 |
+
type=float,
|
339 |
+
default=0.6,
|
340 |
+
help="""We visualize masks
|
341 |
+
obtained by thresholding the self-attention maps to keep xx percent of the mass.""",
|
342 |
+
)
|
343 |
+
parser.add_argument(
|
344 |
+
"--resize",
|
345 |
+
default=None,
|
346 |
+
type=int,
|
347 |
+
nargs="+",
|
348 |
+
help="""Apply a resize transformation to input image(s). Use if OOM error.
|
349 |
+
Usage (single or W H): --resize 512, --resize 720 1280""",
|
350 |
+
)
|
351 |
+
parser.add_argument(
|
352 |
+
"--video_only",
|
353 |
+
action="store_true",
|
354 |
+
help="""Use this flag if you only want to generate a video and not all attention images.
|
355 |
+
If used, --input_path must be set to the folder of attention images. Ex: ./attention/""",
|
356 |
+
)
|
357 |
+
parser.add_argument(
|
358 |
+
"--fps",
|
359 |
+
default=30.0,
|
360 |
+
type=float,
|
361 |
+
help="FPS of input / output video. Automatically set if you extract frames from a video.",
|
362 |
+
)
|
363 |
+
parser.add_argument(
|
364 |
+
"--video_format",
|
365 |
+
default="mp4",
|
366 |
+
type=str,
|
367 |
+
choices=["mp4", "avi"],
|
368 |
+
help="Format of generated video (mp4 or avi).",
|
369 |
+
)
|
370 |
+
|
371 |
+
return parser.parse_args()
|
372 |
+
|
373 |
+
|
374 |
+
if __name__ == "__main__":
|
375 |
+
args = parse_args()
|
376 |
+
|
377 |
+
vg = VideoGenerator(args)
|
378 |
+
vg.run()
|
dino/vision_transformer.py
ADDED
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Mostly copy-paste from timm library.
|
16 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
17 |
+
"""
|
18 |
+
import math
|
19 |
+
from functools import partial
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
|
24 |
+
from utils import trunc_normal_
|
25 |
+
|
26 |
+
|
27 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
28 |
+
if drop_prob == 0.0 or not training:
|
29 |
+
return x
|
30 |
+
keep_prob = 1 - drop_prob
|
31 |
+
shape = (x.shape[0],) + (1,) * (
|
32 |
+
x.ndim - 1
|
33 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
34 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
35 |
+
random_tensor.floor_() # binarize
|
36 |
+
output = x.div(keep_prob) * random_tensor
|
37 |
+
return output
|
38 |
+
|
39 |
+
|
40 |
+
class DropPath(nn.Module):
|
41 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
42 |
+
|
43 |
+
def __init__(self, drop_prob=None):
|
44 |
+
super(DropPath, self).__init__()
|
45 |
+
self.drop_prob = drop_prob
|
46 |
+
|
47 |
+
def forward(self, x):
|
48 |
+
return drop_path(x, self.drop_prob, self.training)
|
49 |
+
|
50 |
+
|
51 |
+
class Mlp(nn.Module):
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
in_features,
|
55 |
+
hidden_features=None,
|
56 |
+
out_features=None,
|
57 |
+
act_layer=nn.GELU,
|
58 |
+
drop=0.0,
|
59 |
+
):
|
60 |
+
super().__init__()
|
61 |
+
out_features = out_features or in_features
|
62 |
+
hidden_features = hidden_features or in_features
|
63 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
64 |
+
self.act = act_layer()
|
65 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
66 |
+
self.drop = nn.Dropout(drop)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
x = self.fc1(x)
|
70 |
+
x = self.act(x)
|
71 |
+
x = self.drop(x)
|
72 |
+
x = self.fc2(x)
|
73 |
+
x = self.drop(x)
|
74 |
+
return x
|
75 |
+
|
76 |
+
|
77 |
+
class Attention(nn.Module):
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
dim,
|
81 |
+
num_heads=8,
|
82 |
+
qkv_bias=False,
|
83 |
+
qk_scale=None,
|
84 |
+
attn_drop=0.0,
|
85 |
+
proj_drop=0.0,
|
86 |
+
):
|
87 |
+
super().__init__()
|
88 |
+
self.num_heads = num_heads
|
89 |
+
head_dim = dim // num_heads
|
90 |
+
self.scale = qk_scale or head_dim**-0.5
|
91 |
+
|
92 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
93 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
94 |
+
self.proj = nn.Linear(dim, dim)
|
95 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
96 |
+
|
97 |
+
def forward(self, x):
|
98 |
+
B, N, C = x.shape
|
99 |
+
qkv = (
|
100 |
+
self.qkv(x)
|
101 |
+
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
102 |
+
.permute(2, 0, 3, 1, 4)
|
103 |
+
)
|
104 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
105 |
+
|
106 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
107 |
+
attn = attn.softmax(dim=-1)
|
108 |
+
attn = self.attn_drop(attn)
|
109 |
+
|
110 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
111 |
+
x = self.proj(x)
|
112 |
+
x = self.proj_drop(x)
|
113 |
+
return x, attn
|
114 |
+
|
115 |
+
|
116 |
+
class Block(nn.Module):
|
117 |
+
def __init__(
|
118 |
+
self,
|
119 |
+
dim,
|
120 |
+
num_heads,
|
121 |
+
mlp_ratio=4.0,
|
122 |
+
qkv_bias=False,
|
123 |
+
qk_scale=None,
|
124 |
+
drop=0.0,
|
125 |
+
attn_drop=0.0,
|
126 |
+
drop_path=0.0,
|
127 |
+
act_layer=nn.GELU,
|
128 |
+
norm_layer=nn.LayerNorm,
|
129 |
+
):
|
130 |
+
super().__init__()
|
131 |
+
self.norm1 = norm_layer(dim)
|
132 |
+
self.attn = Attention(
|
133 |
+
dim,
|
134 |
+
num_heads=num_heads,
|
135 |
+
qkv_bias=qkv_bias,
|
136 |
+
qk_scale=qk_scale,
|
137 |
+
attn_drop=attn_drop,
|
138 |
+
proj_drop=drop,
|
139 |
+
)
|
140 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
141 |
+
self.norm2 = norm_layer(dim)
|
142 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
143 |
+
self.mlp = Mlp(
|
144 |
+
in_features=dim,
|
145 |
+
hidden_features=mlp_hidden_dim,
|
146 |
+
act_layer=act_layer,
|
147 |
+
drop=drop,
|
148 |
+
)
|
149 |
+
|
150 |
+
def forward(self, x, return_attention=False):
|
151 |
+
y, attn = self.attn(self.norm1(x))
|
152 |
+
if return_attention:
|
153 |
+
return attn
|
154 |
+
x = x + self.drop_path(y)
|
155 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
156 |
+
return x
|
157 |
+
|
158 |
+
|
159 |
+
class PatchEmbed(nn.Module):
|
160 |
+
"""Image to Patch Embedding"""
|
161 |
+
|
162 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
163 |
+
super().__init__()
|
164 |
+
num_patches = (img_size // patch_size) * (img_size // patch_size)
|
165 |
+
self.img_size = img_size
|
166 |
+
self.patch_size = patch_size
|
167 |
+
self.num_patches = num_patches
|
168 |
+
|
169 |
+
self.proj = nn.Conv2d(
|
170 |
+
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
|
171 |
+
)
|
172 |
+
|
173 |
+
def forward(self, x):
|
174 |
+
B, C, H, W = x.shape
|
175 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
176 |
+
return x
|
177 |
+
|
178 |
+
|
179 |
+
class VisionTransformer(nn.Module):
|
180 |
+
"""Vision Transformer"""
|
181 |
+
|
182 |
+
def __init__(
|
183 |
+
self,
|
184 |
+
img_size=[224],
|
185 |
+
patch_size=16,
|
186 |
+
in_chans=3,
|
187 |
+
num_classes=0,
|
188 |
+
embed_dim=768,
|
189 |
+
depth=12,
|
190 |
+
num_heads=12,
|
191 |
+
mlp_ratio=4.0,
|
192 |
+
qkv_bias=False,
|
193 |
+
qk_scale=None,
|
194 |
+
drop_rate=0.0,
|
195 |
+
attn_drop_rate=0.0,
|
196 |
+
drop_path_rate=0.0,
|
197 |
+
norm_layer=nn.LayerNorm,
|
198 |
+
**kwargs
|
199 |
+
):
|
200 |
+
super().__init__()
|
201 |
+
self.num_features = self.embed_dim = embed_dim
|
202 |
+
|
203 |
+
self.patch_embed = PatchEmbed(
|
204 |
+
img_size=img_size[0],
|
205 |
+
patch_size=patch_size,
|
206 |
+
in_chans=in_chans,
|
207 |
+
embed_dim=embed_dim,
|
208 |
+
)
|
209 |
+
num_patches = self.patch_embed.num_patches
|
210 |
+
|
211 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
212 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
213 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
214 |
+
|
215 |
+
dpr = [
|
216 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
217 |
+
] # stochastic depth decay rule
|
218 |
+
self.blocks = nn.ModuleList(
|
219 |
+
[
|
220 |
+
Block(
|
221 |
+
dim=embed_dim,
|
222 |
+
num_heads=num_heads,
|
223 |
+
mlp_ratio=mlp_ratio,
|
224 |
+
qkv_bias=qkv_bias,
|
225 |
+
qk_scale=qk_scale,
|
226 |
+
drop=drop_rate,
|
227 |
+
attn_drop=attn_drop_rate,
|
228 |
+
drop_path=dpr[i],
|
229 |
+
norm_layer=norm_layer,
|
230 |
+
)
|
231 |
+
for i in range(depth)
|
232 |
+
]
|
233 |
+
)
|
234 |
+
self.norm = norm_layer(embed_dim)
|
235 |
+
|
236 |
+
# Classifier head
|
237 |
+
self.head = (
|
238 |
+
nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
239 |
+
)
|
240 |
+
|
241 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
242 |
+
trunc_normal_(self.cls_token, std=0.02)
|
243 |
+
self.apply(self._init_weights)
|
244 |
+
|
245 |
+
def _init_weights(self, m):
|
246 |
+
if isinstance(m, nn.Linear):
|
247 |
+
trunc_normal_(m.weight, std=0.02)
|
248 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
249 |
+
nn.init.constant_(m.bias, 0)
|
250 |
+
elif isinstance(m, nn.LayerNorm):
|
251 |
+
nn.init.constant_(m.bias, 0)
|
252 |
+
nn.init.constant_(m.weight, 1.0)
|
253 |
+
|
254 |
+
def interpolate_pos_encoding(self, x, w, h):
|
255 |
+
npatch = x.shape[1] - 1
|
256 |
+
N = self.pos_embed.shape[1] - 1
|
257 |
+
if npatch == N and w == h:
|
258 |
+
return self.pos_embed
|
259 |
+
class_pos_embed = self.pos_embed[:, 0]
|
260 |
+
patch_pos_embed = self.pos_embed[:, 1:]
|
261 |
+
dim = x.shape[-1]
|
262 |
+
w0 = w // self.patch_embed.patch_size
|
263 |
+
h0 = h // self.patch_embed.patch_size
|
264 |
+
# we add a small number to avoid floating point error in the interpolation
|
265 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
266 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
267 |
+
patch_pos_embed = nn.functional.interpolate(
|
268 |
+
patch_pos_embed.reshape(
|
269 |
+
1, int(math.sqrt(N)), int(math.sqrt(N)), dim
|
270 |
+
).permute(0, 3, 1, 2),
|
271 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
272 |
+
mode="bicubic",
|
273 |
+
)
|
274 |
+
assert (
|
275 |
+
int(w0) == patch_pos_embed.shape[-2]
|
276 |
+
and int(h0) == patch_pos_embed.shape[-1]
|
277 |
+
)
|
278 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
279 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
280 |
+
|
281 |
+
def prepare_tokens(self, x):
|
282 |
+
B, nc, w, h = x.shape
|
283 |
+
x = self.patch_embed(x) # patch linear embedding
|
284 |
+
|
285 |
+
# add the [CLS] token to the embed patch tokens
|
286 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
287 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
288 |
+
|
289 |
+
# add positional encoding to each token
|
290 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
291 |
+
|
292 |
+
return self.pos_drop(x)
|
293 |
+
|
294 |
+
def forward(self, x):
|
295 |
+
x = self.prepare_tokens(x)
|
296 |
+
for blk in self.blocks:
|
297 |
+
x = blk(x)
|
298 |
+
x = self.norm(x)
|
299 |
+
return x[:, 0]
|
300 |
+
|
301 |
+
def get_last_selfattention(self, x):
|
302 |
+
x = self.prepare_tokens(x)
|
303 |
+
for i, blk in enumerate(self.blocks):
|
304 |
+
if i < len(self.blocks) - 1:
|
305 |
+
x = blk(x)
|
306 |
+
else:
|
307 |
+
# return attention of the last block
|
308 |
+
return blk(x, return_attention=True)
|
309 |
+
|
310 |
+
def get_intermediate_layers(self, x, n=1):
|
311 |
+
x = self.prepare_tokens(x)
|
312 |
+
# we return the output tokens from the `n` last blocks
|
313 |
+
output = []
|
314 |
+
for i, blk in enumerate(self.blocks):
|
315 |
+
x = blk(x)
|
316 |
+
if len(self.blocks) - i <= n:
|
317 |
+
output.append(self.norm(x))
|
318 |
+
return output
|
319 |
+
|
320 |
+
|
321 |
+
def vit_tiny(patch_size=16, **kwargs):
|
322 |
+
model = VisionTransformer(
|
323 |
+
patch_size=patch_size,
|
324 |
+
embed_dim=192,
|
325 |
+
depth=12,
|
326 |
+
num_heads=3,
|
327 |
+
mlp_ratio=4,
|
328 |
+
qkv_bias=True,
|
329 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
330 |
+
**kwargs
|
331 |
+
)
|
332 |
+
return model
|
333 |
+
|
334 |
+
|
335 |
+
def vit_small(patch_size=16, **kwargs):
|
336 |
+
model = VisionTransformer(
|
337 |
+
patch_size=patch_size,
|
338 |
+
embed_dim=384,
|
339 |
+
depth=12,
|
340 |
+
num_heads=6,
|
341 |
+
mlp_ratio=4,
|
342 |
+
qkv_bias=True,
|
343 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
344 |
+
**kwargs
|
345 |
+
)
|
346 |
+
return model
|
347 |
+
|
348 |
+
|
349 |
+
def vit_base(patch_size=16, **kwargs):
|
350 |
+
model = VisionTransformer(
|
351 |
+
patch_size=patch_size,
|
352 |
+
embed_dim=768,
|
353 |
+
depth=12,
|
354 |
+
num_heads=12,
|
355 |
+
mlp_ratio=4,
|
356 |
+
qkv_bias=True,
|
357 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
358 |
+
**kwargs
|
359 |
+
)
|
360 |
+
return model
|
361 |
+
|
362 |
+
|
363 |
+
class DINOHead(nn.Module):
|
364 |
+
def __init__(
|
365 |
+
self,
|
366 |
+
in_dim,
|
367 |
+
out_dim,
|
368 |
+
use_bn=False,
|
369 |
+
norm_last_layer=True,
|
370 |
+
nlayers=3,
|
371 |
+
hidden_dim=2048,
|
372 |
+
bottleneck_dim=256,
|
373 |
+
):
|
374 |
+
super().__init__()
|
375 |
+
nlayers = max(nlayers, 1)
|
376 |
+
if nlayers == 1:
|
377 |
+
self.mlp = nn.Linear(in_dim, bottleneck_dim)
|
378 |
+
else:
|
379 |
+
layers = [nn.Linear(in_dim, hidden_dim)]
|
380 |
+
if use_bn:
|
381 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
382 |
+
layers.append(nn.GELU())
|
383 |
+
for _ in range(nlayers - 2):
|
384 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim))
|
385 |
+
if use_bn:
|
386 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
387 |
+
layers.append(nn.GELU())
|
388 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
|
389 |
+
self.mlp = nn.Sequential(*layers)
|
390 |
+
self.apply(self._init_weights)
|
391 |
+
self.last_layer = nn.utils.weight_norm(
|
392 |
+
nn.Linear(bottleneck_dim, out_dim, bias=False)
|
393 |
+
)
|
394 |
+
self.last_layer.weight_g.data.fill_(1)
|
395 |
+
if norm_last_layer:
|
396 |
+
self.last_layer.weight_g.requires_grad = False
|
397 |
+
|
398 |
+
def _init_weights(self, m):
|
399 |
+
if isinstance(m, nn.Linear):
|
400 |
+
trunc_normal_(m.weight, std=0.02)
|
401 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
402 |
+
nn.init.constant_(m.bias, 0)
|
403 |
+
|
404 |
+
def forward(self, x):
|
405 |
+
x = self.mlp(x)
|
406 |
+
x = nn.functional.normalize(x, dim=-1, p=2)
|
407 |
+
x = self.last_layer(x)
|
408 |
+
return x
|
dino/visualize_attention.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
import argparse
|
17 |
+
import cv2
|
18 |
+
import random
|
19 |
+
import colorsys
|
20 |
+
import requests
|
21 |
+
from io import BytesIO
|
22 |
+
|
23 |
+
import skimage.io
|
24 |
+
from skimage.measure import find_contours
|
25 |
+
import matplotlib.pyplot as plt
|
26 |
+
from matplotlib.patches import Polygon
|
27 |
+
import torch
|
28 |
+
import torch.nn as nn
|
29 |
+
import torchvision
|
30 |
+
from torchvision import transforms as pth_transforms
|
31 |
+
import numpy as np
|
32 |
+
from PIL import Image
|
33 |
+
|
34 |
+
import utils
|
35 |
+
import vision_transformer as vits
|
36 |
+
|
37 |
+
|
38 |
+
def apply_mask(image, mask, color, alpha=0.5):
|
39 |
+
for c in range(3):
|
40 |
+
image[:, :, c] = (
|
41 |
+
image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255
|
42 |
+
)
|
43 |
+
return image
|
44 |
+
|
45 |
+
|
46 |
+
def random_colors(N, bright=True):
|
47 |
+
"""
|
48 |
+
Generate random colors.
|
49 |
+
"""
|
50 |
+
brightness = 1.0 if bright else 0.7
|
51 |
+
hsv = [(i / N, 1, brightness) for i in range(N)]
|
52 |
+
colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
|
53 |
+
random.shuffle(colors)
|
54 |
+
return colors
|
55 |
+
|
56 |
+
|
57 |
+
def display_instances(
|
58 |
+
image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5
|
59 |
+
):
|
60 |
+
fig = plt.figure(figsize=figsize, frameon=False)
|
61 |
+
ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
|
62 |
+
ax.set_axis_off()
|
63 |
+
fig.add_axes(ax)
|
64 |
+
ax = plt.gca()
|
65 |
+
|
66 |
+
N = 1
|
67 |
+
mask = mask[None, :, :]
|
68 |
+
# Generate random colors
|
69 |
+
colors = random_colors(N)
|
70 |
+
|
71 |
+
# Show area outside image boundaries.
|
72 |
+
height, width = image.shape[:2]
|
73 |
+
margin = 0
|
74 |
+
ax.set_ylim(height + margin, -margin)
|
75 |
+
ax.set_xlim(-margin, width + margin)
|
76 |
+
ax.axis("off")
|
77 |
+
masked_image = image.astype(np.uint32).copy()
|
78 |
+
for i in range(N):
|
79 |
+
color = colors[i]
|
80 |
+
_mask = mask[i]
|
81 |
+
if blur:
|
82 |
+
_mask = cv2.blur(_mask, (10, 10))
|
83 |
+
# Mask
|
84 |
+
masked_image = apply_mask(masked_image, _mask, color, alpha)
|
85 |
+
# Mask Polygon
|
86 |
+
# Pad to ensure proper polygons for masks that touch image edges.
|
87 |
+
if contour:
|
88 |
+
padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2))
|
89 |
+
padded_mask[1:-1, 1:-1] = _mask
|
90 |
+
contours = find_contours(padded_mask, 0.5)
|
91 |
+
for verts in contours:
|
92 |
+
# Subtract the padding and flip (y, x) to (x, y)
|
93 |
+
verts = np.fliplr(verts) - 1
|
94 |
+
p = Polygon(verts, facecolor="none", edgecolor=color)
|
95 |
+
ax.add_patch(p)
|
96 |
+
ax.imshow(masked_image.astype(np.uint8), aspect="auto")
|
97 |
+
fig.savefig(fname)
|
98 |
+
print(f"{fname} saved.")
|
99 |
+
return
|
100 |
+
|
101 |
+
|
102 |
+
if __name__ == "__main__":
|
103 |
+
parser = argparse.ArgumentParser("Visualize Self-Attention maps")
|
104 |
+
parser.add_argument(
|
105 |
+
"--arch",
|
106 |
+
default="vit_small",
|
107 |
+
type=str,
|
108 |
+
choices=["vit_tiny", "vit_small", "vit_base"],
|
109 |
+
help="Architecture (support only ViT atm).",
|
110 |
+
)
|
111 |
+
parser.add_argument(
|
112 |
+
"--patch_size", default=8, type=int, help="Patch resolution of the model."
|
113 |
+
)
|
114 |
+
parser.add_argument(
|
115 |
+
"--pretrained_weights",
|
116 |
+
default="",
|
117 |
+
type=str,
|
118 |
+
help="Path to pretrained weights to load.",
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"--checkpoint_key",
|
122 |
+
default="teacher",
|
123 |
+
type=str,
|
124 |
+
help='Key to use in the checkpoint (example: "teacher")',
|
125 |
+
)
|
126 |
+
parser.add_argument(
|
127 |
+
"--image_path", default=None, type=str, help="Path of the image to load."
|
128 |
+
)
|
129 |
+
parser.add_argument(
|
130 |
+
"--image_size", default=(480, 480), type=int, nargs="+", help="Resize image."
|
131 |
+
)
|
132 |
+
parser.add_argument(
|
133 |
+
"--output_dir", default=".", help="Path where to save visualizations."
|
134 |
+
)
|
135 |
+
parser.add_argument(
|
136 |
+
"--threshold",
|
137 |
+
type=float,
|
138 |
+
default=None,
|
139 |
+
help="""We visualize masks
|
140 |
+
obtained by thresholding the self-attention maps to keep xx% of the mass.""",
|
141 |
+
)
|
142 |
+
args = parser.parse_args()
|
143 |
+
|
144 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
145 |
+
# build model
|
146 |
+
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
|
147 |
+
for p in model.parameters():
|
148 |
+
p.requires_grad = False
|
149 |
+
model.eval()
|
150 |
+
model.to(device)
|
151 |
+
if os.path.isfile(args.pretrained_weights):
|
152 |
+
state_dict = torch.load(args.pretrained_weights, map_location="cpu")
|
153 |
+
if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
|
154 |
+
print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
|
155 |
+
state_dict = state_dict[args.checkpoint_key]
|
156 |
+
# remove `module.` prefix
|
157 |
+
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
158 |
+
# remove `backbone.` prefix induced by multicrop wrapper
|
159 |
+
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
|
160 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
161 |
+
print(
|
162 |
+
"Pretrained weights found at {} and loaded with msg: {}".format(
|
163 |
+
args.pretrained_weights, msg
|
164 |
+
)
|
165 |
+
)
|
166 |
+
else:
|
167 |
+
print(
|
168 |
+
"Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
|
169 |
+
)
|
170 |
+
url = None
|
171 |
+
if args.arch == "vit_small" and args.patch_size == 16:
|
172 |
+
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
|
173 |
+
elif args.arch == "vit_small" and args.patch_size == 8:
|
174 |
+
url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper
|
175 |
+
elif args.arch == "vit_base" and args.patch_size == 16:
|
176 |
+
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
|
177 |
+
elif args.arch == "vit_base" and args.patch_size == 8:
|
178 |
+
url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
|
179 |
+
if url is not None:
|
180 |
+
print(
|
181 |
+
"Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
|
182 |
+
)
|
183 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
184 |
+
url="https://dl.fbaipublicfiles.com/dino/" + url
|
185 |
+
)
|
186 |
+
model.load_state_dict(state_dict, strict=True)
|
187 |
+
else:
|
188 |
+
print(
|
189 |
+
"There is no reference weights available for this model => We use random weights."
|
190 |
+
)
|
191 |
+
|
192 |
+
# open image
|
193 |
+
if args.image_path is None:
|
194 |
+
# user has not specified any image - we use our own image
|
195 |
+
print(
|
196 |
+
"Please use the `--image_path` argument to indicate the path of the image you wish to visualize."
|
197 |
+
)
|
198 |
+
print(
|
199 |
+
"Since no image path have been provided, we take the first image in our paper."
|
200 |
+
)
|
201 |
+
response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png")
|
202 |
+
img = Image.open(BytesIO(response.content))
|
203 |
+
img = img.convert("RGB")
|
204 |
+
elif os.path.isfile(args.image_path):
|
205 |
+
with open(args.image_path, "rb") as f:
|
206 |
+
img = Image.open(f)
|
207 |
+
img = img.convert("RGB")
|
208 |
+
else:
|
209 |
+
print(f"Provided image path {args.image_path} is non valid.")
|
210 |
+
sys.exit(1)
|
211 |
+
transform = pth_transforms.Compose(
|
212 |
+
[
|
213 |
+
pth_transforms.Resize(args.image_size),
|
214 |
+
pth_transforms.ToTensor(),
|
215 |
+
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
216 |
+
]
|
217 |
+
)
|
218 |
+
img = transform(img)
|
219 |
+
|
220 |
+
# make the image divisible by the patch size
|
221 |
+
w, h = (
|
222 |
+
img.shape[1] - img.shape[1] % args.patch_size,
|
223 |
+
img.shape[2] - img.shape[2] % args.patch_size,
|
224 |
+
)
|
225 |
+
img = img[:, :w, :h].unsqueeze(0)
|
226 |
+
|
227 |
+
w_featmap = img.shape[-2] // args.patch_size
|
228 |
+
h_featmap = img.shape[-1] // args.patch_size
|
229 |
+
|
230 |
+
attentions = model.get_last_selfattention(img.to(device))
|
231 |
+
|
232 |
+
nh = attentions.shape[1] # number of head
|
233 |
+
|
234 |
+
# we keep only the output patch attention
|
235 |
+
attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
|
236 |
+
|
237 |
+
if args.threshold is not None:
|
238 |
+
# we keep only a certain percentage of the mass
|
239 |
+
val, idx = torch.sort(attentions)
|
240 |
+
val /= torch.sum(val, dim=1, keepdim=True)
|
241 |
+
cumval = torch.cumsum(val, dim=1)
|
242 |
+
th_attn = cumval > (1 - args.threshold)
|
243 |
+
idx2 = torch.argsort(idx)
|
244 |
+
for head in range(nh):
|
245 |
+
th_attn[head] = th_attn[head][idx2[head]]
|
246 |
+
th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
|
247 |
+
# interpolate
|
248 |
+
th_attn = (
|
249 |
+
nn.functional.interpolate(
|
250 |
+
th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest"
|
251 |
+
)[0]
|
252 |
+
.cpu()
|
253 |
+
.numpy()
|
254 |
+
)
|
255 |
+
|
256 |
+
attentions = attentions.reshape(nh, w_featmap, h_featmap)
|
257 |
+
attentions = (
|
258 |
+
nn.functional.interpolate(
|
259 |
+
attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest"
|
260 |
+
)[0]
|
261 |
+
.cpu()
|
262 |
+
.numpy()
|
263 |
+
)
|
264 |
+
|
265 |
+
# save attentions heatmaps
|
266 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
267 |
+
torchvision.utils.save_image(
|
268 |
+
torchvision.utils.make_grid(img, normalize=True, scale_each=True),
|
269 |
+
os.path.join(args.output_dir, "img.png"),
|
270 |
+
)
|
271 |
+
for j in range(nh):
|
272 |
+
fname = os.path.join(args.output_dir, "attn-head" + str(j) + ".png")
|
273 |
+
plt.imsave(fname=fname, arr=attentions[j], format="png")
|
274 |
+
print(f"{fname} saved.")
|
275 |
+
|
276 |
+
if args.threshold is not None:
|
277 |
+
image = skimage.io.imread(os.path.join(args.output_dir, "img.png"))
|
278 |
+
for j in range(nh):
|
279 |
+
display_instances(
|
280 |
+
image,
|
281 |
+
th_attn[j],
|
282 |
+
fname=os.path.join(
|
283 |
+
args.output_dir,
|
284 |
+
"mask_th" + str(args.threshold) + "_head" + str(j) + ".png",
|
285 |
+
),
|
286 |
+
blur=False,
|
287 |
+
)
|