Commit 0683403 by nzl-thu (parent: 35649f5): Create README.md
Files changed (1): README.md (+209, -0)

# Deep Model Assembling

This repository contains the official code for [Deep Model Assembling](https://arxiv.org/abs/2212.04129).

<p align="center">
<img src="imgs/teaser.png" width="450">
</p>

> **Title**:&emsp;&emsp;[**Deep Model Assembling**](https://arxiv.org/abs/2212.04129)
> **Authors**:&nbsp;&nbsp;[Zanlin Ni](https://scholar.google.com/citations?user=Yibz_asAAAAJ&hl=en&oi=ao), [Yulin Wang](https://scholar.google.com/citations?hl=en&user=gBP38gcAAAAJ), Jiangwei Yu, [Haojun Jiang](https://scholar.google.com/citations?hl=en&user=ULmStp8AAAAJ), [Yue Cao](https://scholar.google.com/citations?hl=en&user=iRUO1ckAAAAJ), [Gao Huang](https://scholar.google.com/citations?user=-P9LwcgAAAAJ&hl=en&oi=ao) (Corresponding Author)
> **Institute**: Tsinghua University and Beijing Academy of Artificial Intelligence (BAAI)
> **Publish**:&nbsp;&nbsp;&nbsp;*arXiv preprint ([arXiv 2212.04129](https://arxiv.org/abs/2212.04129))*
> **Contact**:&nbsp;&nbsp;nzl22 at mails dot tsinghua dot edu dot cn

## News

- `Dec 10, 2022`: Released code for training ViT-B, ViT-L and ViT-H on ImageNet-1K.

## Overview

In this paper, we present a divide-and-conquer strategy for training large models. Our algorithm, Model Assembling, divides a large model into smaller modules, optimizes them independently, and then assembles the trained modules into the target model. Though conceptually simple, our method significantly outperforms end-to-end (E2E) training in terms of both training efficiency and final accuracy. For example, on ViT-H, Model Assembling outperforms E2E training by **2.7%** in accuracy while reducing the training cost by **43%**.

<p align="center">
<img src="imgs/ours.png" width="900">
</p>
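
To make the pipeline concrete, here is a minimal, self-contained PyTorch sketch of the divide / train-independently / assemble idea on a toy stack of blocks. It is purely illustrative (the toy blocks, the 4-way split, and the placeholder objective are simplifications made for this example), not the implementation used in this repository; the actual training recipes are given in the scripts below.

```python
# Toy illustration of divide -> train independently -> assemble.
# ToyBlock and the placeholder objective are stand-ins for the real ViT modules
# and training recipes used in this repository.
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Stand-in for one transformer block."""
    def __init__(self, dim=64):
        super().__init__()
        self.ff = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU())

    def forward(self, x):
        return x + self.ff(x)


def make_module(num_blocks, dim=64):
    return nn.Sequential(*[ToyBlock(dim) for _ in range(num_blocks)])


# 1) Divide: a 12-block "large model" becomes 4 modules of 3 blocks each.
modules = [make_module(3) for _ in range(4)]

# 2) Train each module independently (here: a single dummy step on random data).
for module in modules:
    opt = torch.optim.AdamW(module.parameters(), lr=1e-3)
    x = torch.randn(8, 16, 64)                  # (batch, tokens, dim)
    loss = module(x).pow(2).mean()              # placeholder objective
    loss.backward()
    opt.step()
    opt.zero_grad()

# 3) Assemble: chain the trained modules into the full-depth model,
#    which is then fine-tuned end to end (omitted here).
assembled = nn.Sequential(*modules)
print(assembled(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])
```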

## Data Preparation

- The ImageNet dataset should be prepared as follows:

```
data
├── train
│   ├── folder 1 (class 1)
│   ├── folder 2 (class 2)
│   ├── ...
├── val
│   ├── folder 1 (class 1)
│   ├── folder 2 (class 2)
│   ├── ...
```
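
For reference, a directory laid out this way is the standard `ImageFolder` format (one sub-folder per class) and can be read with `torchvision.datasets.ImageFolder`. The snippet below is only a generic sketch of consuming such a layout; the transforms are placeholders, not the augmentation recipe used by this repository's training code.

```python
# Generic sketch: each sub-folder of data/train and data/val is treated as one class.
# The transforms below are simple placeholders, not this repo's augmentation pipeline.
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

train_set = datasets.ImageFolder("data/train", transform=transform)
val_set = datasets.ImageFolder("data/val", transform=transform)
print(f"{len(train_set.classes)} classes, {len(train_set)} training images")
```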

## Training on ImageNet-1K

- You can add `--use_amp 1` to train with PyTorch's Automatic Mixed Precision (AMP).
- Auto-resuming is enabled by default, i.e., the training script automatically resumes from the latest checkpoint in `output_dir`.
- The effective batch size is `NGPUS` * `batch_size` * `update_freq`. We use an effective batch size of 2048 throughout (e.g., 8 GPUs × 128 images per GPU × 2 gradient accumulation steps = 2048). To avoid OOM issues, you may adjust these arguments accordingly.
- We provide single-node training scripts for simplicity. For multi-node training, modify the launch command in the training scripts to use torchrun:

```bash
python -m torch.distributed.launch --nproc_per_node=${NGPUS} --master_port=23346 --use_env main.py ...

# modify the above command to

torchrun \
    --nnodes=$NODES \
    --nproc_per_node=$NGPUS \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:60900 \
    main.py ...
```

<details>
<summary>Pre-training meta models (click to expand).</summary>

```bash
PHASE=PT # Pre-training
MODEL=base  # for base
# MODEL=large # for large
# MODEL=huge  # for huge
NGPUS=8

args=(
    --phase ${PHASE}
    --model vit_${MODEL}_patch16_224 # for base, large
    # --model vit_${MODEL}_patch14_224 # for huge
    --divided_depths 1 1 1 1
    --output_dir ./log_dir/${PHASE}_${MODEL}

    --batch_size 256
    --epochs 300
    --drop-path 0
)

python -m torch.distributed.launch --nproc_per_node=${NGPUS} --master_port=23346 --use_env main.py "${args[@]}"
```

</details>
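
In the script above, `--divided_depths 1 1 1 1` pre-trains a small meta model in which each of the four modules is represented by a single transformer block. The snippet below is only a conceptual stand-in for what such a shallow ViT-Base-width surrogate looks like, built with timm's generic `VisionTransformer` class; it is not how `main.py` constructs the model.

```python
# Conceptual stand-in for the ViT-Base meta model: ViT-Base width, but only
# 4 blocks in total, i.e. one block per module (--divided_depths 1 1 1 1).
# This is NOT how main.py builds the model; it only illustrates the shape.
import torch
from timm.models.vision_transformer import VisionTransformer

meta_model = VisionTransformer(
    img_size=224, patch_size=16,
    embed_dim=768, num_heads=12,  # ViT-Base width
    depth=4,                      # one block per module
    num_classes=1000,
)

params_m = sum(p.numel() for p in meta_model.parameters()) / 1e6
print(f"{params_m:.1f}M parameters")
print(meta_model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])
```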

<details>
<summary>Modular training (click to expand).</summary>

```bash
PHASE=MT # Modular Training
MODEL=base DEPTH=12  # for base
# MODEL=large DEPTH=24 # for large
# MODEL=huge DEPTH=32  # for huge
NGPUS=8

args=(
    --phase ${PHASE}
    --model vit_${MODEL}_patch16_224 # for base, large
    # --model vit_${MODEL}_patch14_224 # for huge
    --meta_model ./log_dir/PT_${MODEL}/finished_checkpoint.pth # load the pre-trained meta model

    --batch_size 128
    --update_freq 2
    --epochs 100
    --drop-path 0.1
)

# Train each target module. The four commands below are independent and can be run in parallel.
python -m torch.distributed.launch --nproc_per_node=${NGPUS} --master_port=23346 --use_env main.py "${args[@]}" --idx 0 --divided_depths $((DEPTH/4)) 1 1 1 --output_dir ./log_dir/${PHASE}_${MODEL}_0
python -m torch.distributed.launch --nproc_per_node=${NGPUS} --master_port=23346 --use_env main.py "${args[@]}" --idx 1 --divided_depths 1 $((DEPTH/4)) 1 1 --output_dir ./log_dir/${PHASE}_${MODEL}_1
python -m torch.distributed.launch --nproc_per_node=${NGPUS} --master_port=23346 --use_env main.py "${args[@]}" --idx 2 --divided_depths 1 1 $((DEPTH/4)) 1 --output_dir ./log_dir/${PHASE}_${MODEL}_2
python -m torch.distributed.launch --nproc_per_node=${NGPUS} --master_port=23346 --use_env main.py "${args[@]}" --idx 3 --divided_depths 1 1 1 $((DEPTH/4)) --output_dir ./log_dir/${PHASE}_${MODEL}_3
```

</details>
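
During modular training, the target module is "incubated" inside the pre-trained meta model: position `idx` is replaced by a full-depth module (hence `--divided_depths` with `DEPTH/4` at that position and 1 elsewhere), while the remaining positions keep the lightweight meta-model modules, so optimization focuses on the target module only. The toy sketch below illustrates this idea with stand-in blocks; freezing the meta modules here reflects the general idea, not necessarily the exact recipe in `main.py`.

```python
# Toy sketch of incubating target module `idx` inside the meta model.
# The blocks are stand-ins; in the real code, the single-block stages are loaded
# from the pre-trained meta model checkpoint and the objective is classification.
import torch
import torch.nn as nn


def make_module(num_blocks, dim=64):
    # toy stand-in for a stack of transformer blocks
    return nn.Sequential(*[nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU())
                           for _ in range(num_blocks)])


idx, depth = 1, 12                                                   # train module 1 of a depth-12 model
divided_depths = [depth // 4 if i == idx else 1 for i in range(4)]   # -> [1, 3, 1, 1]

stages = []
for i, d in enumerate(divided_depths):
    stage = make_module(d)
    if i != idx:                       # positions filled by the (kept fixed) meta model
        for p in stage.parameters():
            p.requires_grad_(False)
    stages.append(stage)
hybrid = nn.Sequential(*stages)

# Only the target module's parameters are optimized.
opt = torch.optim.AdamW(stages[idx].parameters(), lr=1e-3)
loss = hybrid(torch.randn(8, 16, 64)).pow(2).mean()   # placeholder objective
loss.backward()
opt.step()
```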

<details>
<summary>Assemble & Fine-tuning (click to expand).</summary>

```bash
PHASE=FT # Assemble & Fine-tuning
MODEL=base DEPTH=12  # for base
# MODEL=large DEPTH=24 # for large
# MODEL=huge DEPTH=32  # for huge
NGPUS=8

args=(
    --phase ${PHASE}
    --model vit_${MODEL}_patch16_224 # for base, large
    # --model vit_${MODEL}_patch14_224 # for huge
    --incubation_models ./log_dir/MT_${MODEL}_*/finished_checkpoint.pth # for assembling
    --divided_depths $((DEPTH/4)) $((DEPTH/4)) $((DEPTH/4)) $((DEPTH/4))
    --output_dir ./log_dir/${PHASE}_${MODEL}

    --batch_size 64
    --update_freq 4
    --epochs 100
    --warmup-epochs 0
    --clip-grad 1
    --drop-path 0.1 # for base
    # --drop-path 0.5 # for large
    # --drop-path 0.6 # for huge
)

python -m torch.distributed.launch --nproc_per_node=${NGPUS} --master_port=23346 --use_env main.py "${args[@]}"
```

</details>
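
Conceptually, the assembling step collects the trained target module from each modular-training checkpoint and stitches them back into one full-depth model, which the script above then fine-tunes end to end. The sketch below only illustrates gathering those checkpoints; the `"model"` key is an assumption made for this example, and the actual assembling is handled inside `main.py` via `--incubation_models`.

```python
# Hypothetical sketch of gathering the per-module checkpoints before assembling.
# The "model" key is assumed for illustration; main.py performs the real assembling.
import glob
import torch

module_states = []
for path in sorted(glob.glob("./log_dir/MT_base_*/finished_checkpoint.pth")):
    ckpt = torch.load(path, map_location="cpu")
    module_states.append(ckpt.get("model", ckpt))  # fall back to a raw state dict

# One state dict per target module (module 0..3); each would be loaded into the
# corresponding stage of the full-depth ViT before end-to-end fine-tuning.
print(f"collected {len(module_states)} module checkpoints")
```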

## Results

### Results on ImageNet-1K

<p align="center">
<img src="./imgs/in1k.png" width="900">
</p>

### Results on CIFAR-100

<p align="center">
<img src="./imgs/cifar.png" width="900">
</p>

### Training Efficiency

- Comparing different training budgets

<p align="center">
<img src="./imgs/efficiency.png" width="900">
</p>

- Detailed convergence curves of ViT-Huge

<p align="center">
<img src="./imgs/huge_curve.png" width="450">
</p>

### Data Efficiency

<p align="center">
<img src="./imgs/data_efficiency.png" width="450">
</p>

## Citation

If you find our work helpful, please **star🌟** this repo and **cite📑** our paper. Thanks for your support!

```
@article{Ni2022Assemb,
  title   = {Deep Model Assembling},
  author  = {Ni, Zanlin and Wang, Yulin and Yu, Jiangwei and Jiang, Haojun and Cao, Yue and Huang, Gao},
  journal = {arXiv preprint arXiv:2212.04129},
  year    = {2022}
}
```

## Acknowledgements

Our implementation is mainly based on [deit](https://github.com/facebookresearch/deit). We thank the authors for their clean codebase.

## Contact

If you have any questions or concerns, please send an email to [nzl22@mails.tsinghua.edu.cn](mailto:nzl22@mails.tsinghua.edu.cn).