Upload 10 files
Browse files- README.md +185 -3
- create_acc_rob_pred.py +176 -0
- create_acc_rob_pred_dataset.py +75 -0
- eval_ofa_net.py +94 -0
- hugging_face.py +21 -0
- sample_eval.py +89 -0
- search_best.py +273 -0
- train_ofa_net.py +558 -0
- train_ofa_net_WPS.py +572 -0
- train_teacher_net.py +216 -0
README.md
CHANGED
|
@@ -1,3 +1,185 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<h1 align="center">
|
| 2 |
+
<img src="images/ProARD_logo.png" width="500"/>
|
| 3 |
+
<br/>
|
| 4 |
+
PROARD: PROGRESSIVE ADVERSARIAL ROBUSTNESS DISTILLATION: PROVIDE WIDE RANGE OF ROBUST STUDENTS
|
| 5 |
+
</br>
|
| 6 |
+
</h1>
|
| 7 |
+
<p align="center">
|
| 8 |
+
<a href="#background">Background</a> •
|
| 9 |
+
<a href="#usage">Usage</a> •
|
| 10 |
+
<a href="#code">Code</a> •
|
| 11 |
+
<a href="#citation">Citation</a> •
|
| 12 |
+
</p>
|
| 13 |
+
|
| 14 |
+
## Background
|
| 15 |
+
Progressive Adversarial Robustness Distillation (ProARD), enabling the efficient
|
| 16 |
+
one-time training of a dynamic network that supports a diverse range of accurate and robust student
|
| 17 |
+
networks without requiring retraining. ProARD makes a dynamic deep neural network based on
|
| 18 |
+
dynamic layers by encompassing variations in width, depth, and expansion in each design stage to
|
| 19 |
+
support a wide range of architectures.
|
| 20 |
+
|
| 21 |
+
<h1 align="center">
|
| 22 |
+
<img src="images/ProARD.png" width="1000"/>
|
| 23 |
+
</h1>
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
## Usage
|
| 28 |
+
```
|
| 29 |
+
git clone https://github.com/hamidmousavi0/ProARD.git
|
| 30 |
+
```
|
| 31 |
+
## Code Structure
|
| 32 |
+
```
|
| 33 |
+
- attacks/ # Different Adversarial attack methods (PGD, AutoAttack, FGSM, DeepFool, etc. ([Reference](https://github.com/imrahulr/hat.git)))
|
| 34 |
+
- proard/
|
| 35 |
+
- classification/
|
| 36 |
+
- data_provider/ # The dataset and dataloader definitions for Cifar-10, Cifar-100, and ImageNet.
|
| 37 |
+
- elastic_nn/
|
| 38 |
+
- modules/ # The definition of dynamic layers
|
| 39 |
+
- networks/ # The definition of dynamic networks
|
| 40 |
+
- training/ # Progressive training
|
| 41 |
+
-networks/ # The original networks
|
| 42 |
+
-run_anager/ # The Configs and distributed training
|
| 43 |
+
- nas
|
| 44 |
+
- accuracy_predictor/ # The accuracy and robustness predictor
|
| 45 |
+
- efficiency_predictor/ # The efficiency predictor
|
| 46 |
+
- search_algorithm/ # The Multi-Objective Search Engine
|
| 47 |
+
- utils/ # Utility functions
|
| 48 |
+
- model_zoo.py # All the models for evaluation
|
| 49 |
+
- create_acc_rob_pred_dataset.py # Create dataset to train the accuracy-robustness predictor.
|
| 50 |
+
- create_acc_rob_pred.py # make the predictor model.
|
| 51 |
+
- eval_ofa_net.py # Eval the sub-nets
|
| 52 |
+
- search_best.py # Search the best sub-net
|
| 53 |
+
- train_ofa_net_WPS.py # train the dynamic network without progressive training.
|
| 54 |
+
- train_ofa_net.py # Train the dynamic network with progressive training.
|
| 55 |
+
- train_teacher_net.py # Train teacher network for Robust knowledge distillation.
|
| 56 |
+
|
| 57 |
+
```
|
| 58 |
+
### Installing
|
| 59 |
+
|
| 60 |
+
**From Source**
|
| 61 |
+
|
| 62 |
+
Download this repository into your project folder.
|
| 63 |
+
|
| 64 |
+
### Details of the usage
|
| 65 |
+
|
| 66 |
+
## Evaluation
|
| 67 |
+
|
| 68 |
+
```
|
| 69 |
+
python eval_ofa_net.py --path path of dataset --net Dynamic net name (ResNet50, MBV3)
|
| 70 |
+
--dataset (cifar10, cifar100) --robust_mode (True, False)
|
| 71 |
+
--WPS (True, False) --attack ('fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd','squar_attack','autoattack','apgd_ce')
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
## Training
|
| 76 |
+
|
| 77 |
+
### Step-0: Train Teacher Net
|
| 78 |
+
|
| 79 |
+
```
|
| 80 |
+
horovodrun -np 4 python train_teacher_net.py --model_name ("ResNet50", "MBV3") --dataset (cifar10, cifar100)
|
| 81 |
+
--robust_mode (True, False) --epsilon 0.031 --num_steps 10
|
| 82 |
+
--step_size 0.0078 --distance 'l-inf' --train_criterion 'trades'
|
| 83 |
+
--attack_type 'linf-pgd'
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
### Step-1: Dynamic Width/Kernel training
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
```
|
| 90 |
+
horovodrun -np 4 python train_ofa_net.py --task 'width' or 'kernel' --model_name ("ResNet50", "MBV3") --dataset (cifar10, cifar100)
|
| 91 |
+
--robust_mode (True, False) --epsilon 0.031 --num_steps 10
|
| 92 |
+
--step_size 0.0078 --distance 'l-inf' --train_criterion 'trades'
|
| 93 |
+
--attack_type 'linf-pgd' --kd_criterion 'rslad' --phase 1
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
### Step-2: Dynamic Width/Kernel and depth training
|
| 97 |
+
|
| 98 |
+
##### Phase-1
|
| 99 |
+
```
|
| 100 |
+
horovodrun -np 4 python train_ofa_net.py --task 'depth' --model_name ("ResNet50", "MBV3") --dataset (cifar10, cifar100)
|
| 101 |
+
--robust_mode (True, False) --epsilon 0.031 --num_steps 10
|
| 102 |
+
--step_size 0.0078 --distance 'l-inf' --train_criterion 'trades'
|
| 103 |
+
--attack_type 'linf-pgd' --kd_criterion 'rslad' --phase 1
|
| 104 |
+
```
|
| 105 |
+
##### Phase-2
|
| 106 |
+
```
|
| 107 |
+
horovodrun -np 4 python train_ofa_net.py --task 'depth' --model_name ("ResNet50", "MBV3") --dataset (cifar10, cifar100)
|
| 108 |
+
--robust_mode (True, False) --epsilon 0.031 --num_steps 10
|
| 109 |
+
--step_size 0.0078 --distance 'l-inf' --train_criterion 'trades'
|
| 110 |
+
--attack_type 'linf-pgd' --kd_criterion 'rslad' --phase 2
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
### Step-3: Dynamic Width/Kernel, depth, and expand training
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
##### Phase-1
|
| 117 |
+
```
|
| 118 |
+
horovodrun -np 4 python train_ofa_net.py --task 'expand' --model_name ("ResNet50", "MBV3") --dataset (cifar10, cifar100)
|
| 119 |
+
--robust_mode (True, False) --epsilon 0.031 --num_steps 10
|
| 120 |
+
--step_size 0.0078 --distance 'l-inf' --train_criterion 'trades'
|
| 121 |
+
--attack_type 'linf-pgd' --kd_criterion 'rslad' --phase 1
|
| 122 |
+
```
|
| 123 |
+
##### Phase-2
|
| 124 |
+
```
|
| 125 |
+
horovodrun -np 4 python train_ofa_net.py --task 'expand' --model_name ("ResNet50", "MBV3") --dataset (cifar10, cifar100)
|
| 126 |
+
--robust_mode (True, False) --epsilon 0.031 --num_steps 10
|
| 127 |
+
--step_size 0.0078 --distance 'l-inf' --train_criterion 'trades'
|
| 128 |
+
--attack_type 'linf-pgd' --kd_criterion 'rslad' --phase 2
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
<!--
|
| 136 |
+
|
| 137 |
+
* **ProAct** (the proposed algorithm) ([paper](https://arxiv.org/abs/2406.06313) and ([code](https://github.com/hamidmousavi0/reliable-relu-toolbox/tree/master/rrelu/search_bound/proact.py)).
|
| 138 |
+
* **FitAct** ([paper](https://arxiv.org/pdf/2112.13544) and [code](https://github.com/hamidmousavi0/reliable-relu-toolbox/tree/master/rrelu/search_bound/fitact.py)).
|
| 139 |
+
* **FtClipAct** ([paper](https://arxiv.org/pdf/1912.00941) and [code](https://github.com/hamidmousavi0/reliable-relu-toolbox/tree/master/rrelu/search_bound/ftclip.py)).
|
| 140 |
+
* **Ranger** ([paper](https://arxiv.org/pdf/2003.13874) and [code](https://github.com/hamidmousavi0/reliable-relu-toolbox/tree/master/rrelu/search_bound/ranger.py)).
|
| 141 |
+
-->
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
<!-- Use the following notebook to learn the main steps of the tool.
|
| 147 |
+
[](https://github.com/hamidmousavi0/reliable-relu-toolbox/blob/master/RReLU.ipynb)-->
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
## To-do list
|
| 151 |
+
- [ ] Add object detection Task
|
| 152 |
+
- [ ] Add Transformers architectures
|
| 153 |
+
|
| 154 |
+
<!--
|
| 155 |
+
### Run search in the command line
|
| 156 |
+
|
| 157 |
+
When you download this repository into your project folder.
|
| 158 |
+
```
|
| 159 |
+
torchrun --nproc_per_node=2 search.py --dataset cifar10 --data_path "./dataset/cifar10" --batch_size 128 --model "resnet20" --n_worker 32 \
|
| 160 |
+
--name_relu_bound "zero" --name_serach_bound "ranger" --bounds_type "layer" --bitflip "fixed" --image_size 32 --pretrained_model
|
| 161 |
+
```
|
| 162 |
+
-->
|
| 163 |
+
## Citation
|
| 164 |
+
|
| 165 |
+
View the [published paper(preprint), Accepted in IJCNN 2025](https://www.arxiv.org/pdf/2506.07666).
|
| 166 |
+
<!--
|
| 167 |
+
```
|
| 168 |
+
@article{mousavi2024proact,
|
| 169 |
+
title={ProAct: Progressive Training for Hybrid Clipped Activation Function to Enhance Resilience of DNNs},
|
| 170 |
+
author={Mousavi, Seyedhamidreza and Ahmadilivani, Mohammad Hasan and Raik, Jaan and Jenihhin, Maksim and Daneshtalab, Masoud},
|
| 171 |
+
journal={arXiv preprint arXiv:2406.06313},
|
| 172 |
+
year={2024}
|
| 173 |
+
}
|
| 174 |
+
```
|
| 175 |
+
-->
|
| 176 |
+
## Acknowledgment
|
| 177 |
+
|
| 178 |
+
We acknowledge the National Academic Infrastructure for Supercomputing in Sweden (NAISS), partially funded by the Swedish Research Council through grant agreement no
|
| 179 |
+
|
| 180 |
+
## Contributors
|
| 181 |
+
Some of the code in this repository is based on the following amazing works:
|
| 182 |
+
|
| 183 |
+
[Once-For-All](https://github.com/mit-han-lab/once-for-all.git)
|
| 184 |
+
|
| 185 |
+
[Hat](https://github.com/imrahulr/hat.git)
|
create_acc_rob_pred.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import argparse
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from tqdm.auto import tqdm
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.optim import *
|
| 11 |
+
from torch.optim.lr_scheduler import *
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
from torchprofile import profile_macs
|
| 14 |
+
from torchvision.datasets import *
|
| 15 |
+
from torchvision.transforms import *
|
| 16 |
+
from proard.classification.data_providers.imagenet import ImagenetDataProvider
|
| 17 |
+
from proard.classification.run_manager import DistributedClassificationRunConfig, DistributedRunManager
|
| 18 |
+
from proard.model_zoo import DYN_net
|
| 19 |
+
from proard.nas.accuracy_predictor import AccuracyDataset,AccuracyPredictor,ResNetArchEncoder,RobustnessPredictor,MobileNetArchEncoder,AccuracyRobustnessDataset,Accuracy_Robustness_Predictor
|
| 20 |
+
parser = argparse.ArgumentParser()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def RMSELoss(yhat, y):
    """Return the root-mean-square error between predictions *yhat* and targets *y*."""
    squared_error = (yhat - y) ** 2
    return torch.mean(squared_error).sqrt()
|
| 25 |
+
def train(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: Optimizer,
    callbacks=None,
    epochs=10,
    save_path=None,
) -> nn.Module:
    """Train the joint accuracy/robustness predictor.

    Each batch yields ``(arch_features, accuracy_target, robustness_target)``;
    the model emits a 2-column tensor whose column 0 is regressed against the
    accuracy target and column 1 against the robustness target, both with
    *criterion*, and the two losses are summed.

    :param callbacks: optional iterable of zero-argument callables invoked
        after every optimizer step.
    :param save_path: where to persist the final ``state_dict``; skipped
        when ``None``.
    :return: the trained model (the original signature said ``-> None`` but
        always returned the model; the annotation is corrected here).
    """
    model.cuda()
    model.train()
    for epoch in range(epochs):
        print(epoch)
        for inputs, targets_acc, targets_rob in tqdm(dataloader, desc='train', leave=False):
            inputs = inputs.float().cuda()
            targets_acc = targets_acc.cuda()
            targets_rob = targets_rob.cuda()

            # Reset the gradients (from the last iteration)
            optimizer.zero_grad()

            # Forward inference: column 0 -> accuracy, column 1 -> robustness.
            outputs = model(inputs)
            loss = criterion(outputs[:, 0], targets_acc) + criterion(outputs[:, 1], targets_rob)

            # Backward propagation
            loss.backward()

            # Update optimizer (LR scheduler intentionally disabled upstream).
            optimizer.step()

            if callbacks is not None:
                for callback in callbacks:
                    callback()
    # Bug fix: the original called torch.save unconditionally, which raised
    # when save_path kept its default of None.
    if save_path is not None:
        torch.save(model.state_dict(), save_path)
    return model
|
| 62 |
+
|
| 63 |
+
@torch.inference_mode()
def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
) -> float:
    """Evaluate predictor RMSE over the entire dataloader.

    Bug fix: the original computed (and returned) the RMSE of a single
    batch — the ``return`` sat inside the loop body — so the reported
    metric ignored every batch after the first.  Here all predictions and
    targets are gathered first and the RMSE is computed over the full set.

    :return: sum of the accuracy-column RMSE and robustness-column RMSE.
    """
    model.eval()

    preds_acc, preds_rob = [], []
    gold_acc, gold_rob = [], []
    for inputs, targets_acc, targets_rob in tqdm(dataloader, desc="eval", leave=False):
        # Move the data from CPU to GPU
        inputs = inputs.cuda()
        targets_acc = targets_acc.cuda()
        targets_rob = targets_rob.cuda()

        # Inference: column 0 predicts accuracy, column 1 robustness.
        outputs = model(inputs)
        preds_acc.append(outputs[:, 0])
        preds_rob.append(outputs[:, 1])
        gold_acc.append(targets_acc)
        gold_rob.append(targets_rob)

    acc_rmse = RMSELoss(torch.cat(preds_acc), torch.cat(gold_acc))
    rob_rmse = RMSELoss(torch.cat(preds_rob), torch.cat(gold_rob))
    print(acc_rmse, rob_rmse)
    return acc_rmse + rob_rmse
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_model_flops(model, inputs):
    """Return the multiply-accumulate count of one forward pass of *model* on *inputs*."""
    return profile_macs(model, inputs)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_model_size(model: nn.Module, data_width=32):
    """
    calculate the model size in bits
    :param data_width: #bits per element
    """
    total_params = sum(param.numel() for param in model.parameters())
    return total_params * data_width
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
parser.add_argument(
|
| 106 |
+
"-p", "--path", help="The path of cifar10", type=str, default="/dataset/cifar10"
|
| 107 |
+
)
|
| 108 |
+
parser.add_argument("-g", "--gpu", help="The gpu(s) to use", type=str, default="all")
|
| 109 |
+
parser.add_argument(
|
| 110 |
+
"-b",
|
| 111 |
+
"--batch_size",
|
| 112 |
+
help="The batch on every device for validation",
|
| 113 |
+
type=int,
|
| 114 |
+
default=32,
|
| 115 |
+
)
|
| 116 |
+
parser.add_argument("-j", "--workers", help="Number of workers", type=int, default=20)
|
| 117 |
+
parser.add_argument(
|
| 118 |
+
"-n",
|
| 119 |
+
"--net",
|
| 120 |
+
metavar="DYNNET",
|
| 121 |
+
default="ResNet50",
|
| 122 |
+
choices=[
|
| 123 |
+
"ResNet50",
|
| 124 |
+
"MBV3",
|
| 125 |
+
"ProxylessNASNet",
|
| 126 |
+
],
|
| 127 |
+
help="Dyanmic networks",
|
| 128 |
+
)
|
| 129 |
+
parser.add_argument(
|
| 130 |
+
"--dataset", type=str, default="cifar10" ,choices=["cifar10", "cifar100", "imagenet"]
|
| 131 |
+
)
|
| 132 |
+
parser.add_argument("--train_criterion", type=str, default="trades",choices=["trades","sat","mart","hat"])
|
| 133 |
+
parser.add_argument(
|
| 134 |
+
"--robust_mode", type=bool, default=True
|
| 135 |
+
)
|
| 136 |
+
args = parser.parse_args()
|
| 137 |
+
if args.net == "ResNet50":
|
| 138 |
+
arch = ResNetArchEncoder(image_size_list=[224 if args.dataset == 'imagenet' else 32],depth_list=[0,1,2],expand_list=[0.2,0.25,0.35],width_mult_list=[0.65,0.8,1.0])
|
| 139 |
+
else:
|
| 140 |
+
arch = MobileNetArchEncoder (image_size_list=[224 if args.dataset == 'imagenet' else 32],depth_list=[2,3,4],expand_list=[3,4,6],ks_list=[3,5,7])
|
| 141 |
+
print(arch)
|
| 142 |
+
acc_data = AccuracyRobustnessDataset("./acc_rob_data_{}_{}_{}".format(args.dataset,args.net,args.train_criterion))
|
| 143 |
+
train_loader, valid_loader, base_acc ,base_rob = acc_data.build_acc_data_loader(arch)
|
| 144 |
+
acc_pred_network = Accuracy_Robustness_Predictor(arch_encoder=arch,base_acc_val=None)
|
| 145 |
+
# optimizer_ = torch.optim.Adam(acc_pred_network.parameters(),lr=1e-3,weight_decay=1e-4)
|
| 146 |
+
# criterion = nn.MSELoss()
|
| 147 |
+
# acc_pred_network = train(acc_pred_network,train_loader,criterion,optimizer_,callbacks=None, epochs=50,save_path ="./acc_rob_data_{}_{}_{}/src/model_acc_rob.pth".format(args.dataset,args.net,args.train_criterion).format(args.dataset))
|
| 148 |
+
acc_pred_network.load_state_dict(torch.load("./acc_rob_data_{}_{}_{}/src/model_acc_rob.pth".format(args.dataset,args.net,args.train_criterion)))
|
| 149 |
+
print(evaluate(acc_pred_network,valid_loader))
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# import numpy as np
|
| 154 |
+
# accs=[]
|
| 155 |
+
# robs=[]
|
| 156 |
+
# pred_accs=[]
|
| 157 |
+
# pred_robs=[]
|
| 158 |
+
# for x,acc,rob, in valid_loader:
|
| 159 |
+
# for ac in acc:
|
| 160 |
+
# accs.append(ac.item()*100)
|
| 161 |
+
# for ro in rob:
|
| 162 |
+
# robs.append(ro.item()*100)
|
| 163 |
+
|
| 164 |
+
# for x,acc,rob, in valid_loader:
|
| 165 |
+
# for arch in x:
|
| 166 |
+
# acc ,rob = acc_pred_network(arch.cuda())
|
| 167 |
+
# pred_accs.append(acc.item()*100)
|
| 168 |
+
# pred_robs.append(rob.item()*100)
|
| 169 |
+
# print(accs,robs)
|
| 170 |
+
# print(pred_accs,pred_robs)
|
| 171 |
+
# np.savetxt("./results/accs.csv", np.array(accs), delimiter=",")
|
| 172 |
+
# np.savetxt("./results/robs.csv", np.array(robs), delimiter=",")
|
| 173 |
+
# np.savetxt("./results/pred_accs.csv", np.array(pred_accs), delimiter=",")
|
| 174 |
+
# np.savetxt("./results/pred_robs.csv", np.array(pred_robs), delimiter=",")
|
| 175 |
+
|
| 176 |
+
|
create_acc_rob_pred_dataset.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
from proard.classification.data_providers.imagenet import ImagenetDataProvider
|
| 6 |
+
from proard.classification.run_manager import DistributedClassificationRunConfig, DistributedRunManager
|
| 7 |
+
from proard.model_zoo import DYN_net
|
| 8 |
+
from proard.nas.accuracy_predictor import AccuracyRobustnessDataset
|
| 9 |
+
import horovod.torch as hvd
|
| 10 |
+
parser = argparse.ArgumentParser()
|
| 11 |
+
parser.add_argument(
|
| 12 |
+
"-p", "--path", help="The path of cifar10", type=str, default="/dataset/cifar10"
|
| 13 |
+
)
|
| 14 |
+
parser.add_argument("-g", "--gpu", help="The gpu(s) to use", type=str, default="all")
|
| 15 |
+
parser.add_argument(
|
| 16 |
+
"-b",
|
| 17 |
+
"--batch_size",
|
| 18 |
+
help="The batch on every device for validation",
|
| 19 |
+
type=int,
|
| 20 |
+
default=32,
|
| 21 |
+
)
|
| 22 |
+
parser.add_argument("-j", "--workers", help="Number of workers", type=int, default=20)
|
| 23 |
+
parser.add_argument(
|
| 24 |
+
"-n",
|
| 25 |
+
"--net",
|
| 26 |
+
metavar="DYNNET",
|
| 27 |
+
default="ResNet50",
|
| 28 |
+
choices=[
|
| 29 |
+
"ResNet50",
|
| 30 |
+
"MBV3",
|
| 31 |
+
"ProxylessNASNet",
|
| 32 |
+
"MBV2"
|
| 33 |
+
],
|
| 34 |
+
help="Dynamic networks",
|
| 35 |
+
)
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--dataset", type=str, default="cifar10" ,choices=["cifar10", "cifar100", "imagenet"]
|
| 38 |
+
)
|
| 39 |
+
parser.add_argument("--train_criterion", type=str, default="trades",choices=["trades","sat","mart","hat"])
|
| 40 |
+
parser.add_argument(
|
| 41 |
+
"--robust_mode", type=bool, default=True
|
| 42 |
+
)
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--WPS", type=bool, default=True
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--base", type=bool, default=False
|
| 48 |
+
)
|
| 49 |
+
# Initialize Horovod
|
| 50 |
+
hvd.init()
|
| 51 |
+
# Pin GPU to be used to process local rank (one GPU per process)
|
| 52 |
+
torch.cuda.set_device(hvd.local_rank())
|
| 53 |
+
num_gpus = hvd.size()
|
| 54 |
+
|
| 55 |
+
args = parser.parse_args()
|
| 56 |
+
if args.gpu == "all":
|
| 57 |
+
device_list = range(torch.cuda.device_count())
|
| 58 |
+
args.gpu = ",".join(str(_) for _ in device_list)
|
| 59 |
+
else:
|
| 60 |
+
device_list = [int(_) for _ in args.gpu.split(",")]
|
| 61 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
|
| 62 |
+
args.test_batch_size = args.batch_size # * max(len(device_list), 1)
|
| 63 |
+
ImagenetDataProvider.DEFAULT_PATH = args.path
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
distributed_run_config = DistributedClassificationRunConfig(**args.__dict__, num_replicas=num_gpus, rank=hvd.rank())
|
| 67 |
+
dyn_network = DYN_net(args.net, args.robust_mode , args.dataset, args.train_criterion, pretrained=True,run_config=distributed_run_config,WPS=args.WPS)
|
| 68 |
+
compression = hvd.Compression.none
|
| 69 |
+
distributed_run_manager = DistributedRunManager(".tmp/eval_subnet", dyn_network, distributed_run_config,compression,is_root=(hvd.rank() == 0),init=False)
|
| 70 |
+
distributed_run_manager.save_config()
|
| 71 |
+
# hvd broadcast
|
| 72 |
+
distributed_run_manager.broadcast()
|
| 73 |
+
acc_data = AccuracyRobustnessDataset("./acc_rob_data_WPS_{}_{}_{}".format(args.dataset,args.net,args.train_criterion))
|
| 74 |
+
|
| 75 |
+
acc_data.build_acc_rob_dataset(distributed_run_manager,dyn_network,image_size_list=[224 if args.dataset == "imagenet" else 32])
|
eval_ofa_net.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Once for All: Train One Network and Specialize it for Efficient Deployment
|
| 2 |
+
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
|
| 3 |
+
# International Conference on Learning Representations (ICLR), 2020.
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
from proard.classification.data_providers.imagenet import ImagenetDataProvider
|
| 10 |
+
from proard.classification.data_providers.cifar10 import Cifar10DataProvider
|
| 11 |
+
from proard.classification.data_providers.cifar100 import Cifar100DataProvider
|
| 12 |
+
from proard.classification.run_manager import ClassificationRunConfig, RunManager
|
| 13 |
+
from proard.model_zoo import DYN_net
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
parser = argparse.ArgumentParser()
|
| 17 |
+
parser.add_argument(
|
| 18 |
+
"-p", "--path", help="The path of imagenet", type=str, default="/dataset/imagenet"
|
| 19 |
+
)
|
| 20 |
+
parser.add_argument("-g", "--gpu", help="The gpu(s) to use", type=str, default="all")
|
| 21 |
+
parser.add_argument(
|
| 22 |
+
"-b",
|
| 23 |
+
"--batch-size",
|
| 24 |
+
help="The batch on every device for validation",
|
| 25 |
+
type=int,
|
| 26 |
+
default=16,
|
| 27 |
+
)
|
| 28 |
+
parser.add_argument("-j", "--workers", help="Number of workers", type=int, default=20)
|
| 29 |
+
parser.add_argument(
|
| 30 |
+
"-n",
|
| 31 |
+
"--net",
|
| 32 |
+
metavar="DYNET",
|
| 33 |
+
default="ResNet50",
|
| 34 |
+
choices=[
|
| 35 |
+
"ResNet50",
|
| 36 |
+
"MBV3",
|
| 37 |
+
"ProxylessNASNet",
|
| 38 |
+
"MBV2",
|
| 39 |
+
"WideResNet"
|
| 40 |
+
],
|
| 41 |
+
help="dynamic networks",
|
| 42 |
+
)
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--dataset", type=str, default="cifar10" ,choices=["cifar10", "cifar100", "imagenet"]
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--attack", type=str, default="autoattack" ,choices=['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd','squar_attack','autoattack','apgd_ce']
|
| 48 |
+
)
|
| 49 |
+
parser.add_argument("--train_criterion", type=str, default="trades",choices=["trades","sat","mart","hat"])
|
| 50 |
+
parser.add_argument(
|
| 51 |
+
"--robust_mode", type=bool, default=True
|
| 52 |
+
)
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"--WPS", type=bool, default=False
|
| 55 |
+
)
|
| 56 |
+
parser.add_argument(
|
| 57 |
+
"--base", type=bool, default=False
|
| 58 |
+
)
|
| 59 |
+
args = parser.parse_args()
|
| 60 |
+
if args.gpu == "all":
|
| 61 |
+
device_list = range(torch.cuda.device_count())
|
| 62 |
+
args.gpu = ",".join(str(_) for _ in device_list)
|
| 63 |
+
else:
|
| 64 |
+
device_list = [int(_) for _ in args.gpu.split(",")]
|
| 65 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
|
| 66 |
+
args.batch_size = args.batch_size * max(len(device_list), 1)
|
| 67 |
+
ImagenetDataProvider.DEFAULT_PATH = args.path
|
| 68 |
+
|
| 69 |
+
run_config = ClassificationRunConfig(attack_type=args.attack,dataset= args.dataset, test_batch_size=args.batch_size, n_worker=args.workers,robust_mode=args.robust_mode)
|
| 70 |
+
dyn_network = DYN_net(args.net,args.robust_mode,args.dataset, args.train_criterion ,pretrained=True,run_config=run_config,WPS=args.WPS,base=args.base)
|
| 71 |
+
""" Randomly sample a sub-network,
|
| 72 |
+
you can also manually set the sub-network using:
|
| 73 |
+
dyn_network.set_active_subnet(ks=7, e=6, d=4)
|
| 74 |
+
"""
|
| 75 |
+
if not args.base:
|
| 76 |
+
# dyn_network.set_active_subnet(ks=3, e=4, d=2)
|
| 77 |
+
dyn_network.set_active_subnet(d=2,e=0.35,w=1.0)
|
| 78 |
+
# dyn_network.sample_active_subnet()
|
| 79 |
+
# dyn_network.set_max_net()
|
| 80 |
+
subnet = dyn_network.get_active_subnet(preserve_weight=True)
|
| 81 |
+
# print(subnet)
|
| 82 |
+
else:
|
| 83 |
+
subnet = dyn_network
|
| 84 |
+
""" Test sampled subnet
|
| 85 |
+
"""
|
| 86 |
+
run_manager = RunManager(".tmp/eval_subnet", subnet, run_config, init=False)
|
| 87 |
+
run_config.data_provider.assign_active_img_size(32)
|
| 88 |
+
run_manager.reset_running_statistics(net=subnet)
|
| 89 |
+
|
| 90 |
+
print("Test random subnet:")
|
| 91 |
+
# print(subnet.module_str)
|
| 92 |
+
|
| 93 |
+
loss, (top1, top5,robust1,robust5) = run_manager.validate(net=subnet,is_test=True)
|
| 94 |
+
print("Results: loss=%.5f,\t top1=%.1f,\t top5=%.1f,\t robust1=%.1f,\t robust5=%.1f" % (loss, top1, top5,robust1,robust5))
|
hugging_face.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import interpreter_login
|
| 2 |
+
from huggingface_hub import upload_folder, delete_folder, upload_file
|
| 3 |
+
# interpreter_login()
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# upload_folder(folder_path = "attacks/",path_in_repo="attacks", repo_id="smi08/ProArd")
|
| 8 |
+
# upload_folder(folder_path = "images/",path_in_repo="images", repo_id="smi08/ProArd")
|
| 9 |
+
# upload_folder(folder_path = "proard/",path_in_repo="proard", repo_id="smi08/ProArd")
|
| 10 |
+
# upload_folder(folder_path = "robust_loss/",path_in_repo="robust_loss", repo_id="smi08/ProArd")
|
| 11 |
+
# upload_folder(folder_path = "utils/",path_in_repo="utils", repo_id="smi08/ProArd")
|
| 12 |
+
# delete_folder(path_in_repo="smi08", repo_id="smi08/ProArd")
|
| 13 |
+
upload_file(path_or_fileobj="create_acc_rob_pred_dataset.py",path_in_repo="",repo_id="smi08/ProArd")
|
| 14 |
+
upload_file(path_or_fileobj="create_acc_rob_pred.py",path_in_repo="",repo_id="smi08/ProArd")
|
| 15 |
+
upload_file(path_or_fileobj="eval_ofa_net.py",path_in_repo="",repo_id="smi08/ProArd")
|
| 16 |
+
upload_file(path_or_fileobj="sample_eval.py",path_in_repo="",repo_id="smi08/ProArd")
|
| 17 |
+
upload_file(path_or_fileobj="search_best.py",path_in_repo="",repo_id="smi08/ProArd")
|
| 18 |
+
upload_file(path_or_fileobj="train_ofa_net_WPS.py",path_in_repo="",repo_id="smi08/ProArd")
|
| 19 |
+
upload_file(path_or_fileobj="train_ofa_net.py",path_in_repo="",repo_id="smi08/ProArd")
|
| 20 |
+
upload_file(path_or_fileobj="train_teacher_net.py",path_in_repo="",repo_id="smi08/ProArd")
|
| 21 |
+
upload_file(path_or_fileobj="README.md",path_in_repo="",repo_id="smi08/ProArd")
|
sample_eval.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Once for All: Train One Network and Specialize it for Efficient Deployment
|
| 2 |
+
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
|
| 3 |
+
# International Conference on Learning Representations (ICLR), 2020.
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import argparse
|
| 8 |
+
import sys
|
| 9 |
+
from proard.classification.data_providers.imagenet import ImagenetDataProvider
|
| 10 |
+
from proard.classification.data_providers.cifar10 import Cifar10DataProvider
|
| 11 |
+
from proard.classification.data_providers.cifar100 import Cifar100DataProvider
|
| 12 |
+
from proard.classification.run_manager import ClassificationRunConfig, RunManager,DistributedRunManager
|
| 13 |
+
from proard.model_zoo import DYN_net
|
| 14 |
+
from proard.nas.accuracy_predictor import AccuracyDataset,AccuracyPredictor,ResNetArchEncoder,RobustnessPredictor,MobileNetArchEncoder,AccuracyRobustnessDataset,Accuracy_Robustness_Predictor
|
| 15 |
+
|
| 16 |
+
parser = argparse.ArgumentParser()
|
| 17 |
+
parser.add_argument(
|
| 18 |
+
"-p", "--path", help="The path of imagenet", type=str, default="/dataset/imagenet"
|
| 19 |
+
)
|
| 20 |
+
parser.add_argument("-g", "--gpu", help="The gpu(s) to use", type=str, default="all")
|
| 21 |
+
parser.add_argument(
|
| 22 |
+
"-b",
|
| 23 |
+
"--batch-size",
|
| 24 |
+
help="The batch on every device for validation",
|
| 25 |
+
type=int,
|
| 26 |
+
default=128,
|
| 27 |
+
)
|
| 28 |
+
parser.add_argument("-j", "--workers", help="Number of workers", type=int, default=20)
|
| 29 |
+
parser.add_argument(
|
| 30 |
+
"-n",
|
| 31 |
+
"--net",
|
| 32 |
+
metavar="DYNNET",
|
| 33 |
+
default="MBV3",
|
| 34 |
+
choices=[
|
| 35 |
+
"ResNet50",
|
| 36 |
+
"MBV3",
|
| 37 |
+
"ProxylessNASNet",
|
| 38 |
+
"MBV2"
|
| 39 |
+
],
|
| 40 |
+
help="dynamic networks",
|
| 41 |
+
)
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--dataset", type=str, default="cifar10" ,choices=["cifar10", "cifar100", "imagenet"]
|
| 44 |
+
)
|
| 45 |
+
parser.add_argument("--train_criterion", type=str, default="trades",choices=["trades","sat","mart","hat"])
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--robust_mode", type=bool, default=True
|
| 48 |
+
)
|
| 49 |
+
parser.add_argument(
|
| 50 |
+
"--WPS", type=bool, default=False
|
| 51 |
+
)
|
| 52 |
+
args = parser.parse_args()
|
| 53 |
+
if args.gpu == "all":
|
| 54 |
+
device_list = range(torch.cuda.device_count())
|
| 55 |
+
args.gpu = ",".join(str(_) for _ in device_list)
|
| 56 |
+
else:
|
| 57 |
+
device_list = [int(_) for _ in args.gpu.split(",")]
|
| 58 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
|
| 59 |
+
args.batch_size = args.batch_size * max(len(device_list), 1)
|
| 60 |
+
ImagenetDataProvider.DEFAULT_PATH = args.path
|
| 61 |
+
|
| 62 |
+
run_config = ClassificationRunConfig(dataset= args.dataset, test_batch_size=args.batch_size, n_worker=args.workers,robust_mode=args.robust_mode)
|
| 63 |
+
dyn_network = DYN_net(args.net,args.robust_mode,args.dataset, args.train_criterion ,pretrained=True,run_config=run_config,WPS=args.WPS)
|
| 64 |
+
""" Randomly sample a sub-network,
|
| 65 |
+
you can also manually set the sub-network using:
|
| 66 |
+
dyn_network.set_active_subnet(ks=7, e=6, d=4)
|
| 67 |
+
"""
|
| 68 |
+
# dyn_network.set_active_subnet(ks=3, e=3, d=2)
|
| 69 |
+
# dyn_network.set_active_subnet(d=4,e=0.25,w=1)
|
| 70 |
+
import random
|
| 71 |
+
import numpy as np
|
| 72 |
+
random.seed(0)
|
| 73 |
+
np.random.seed(0)
|
| 74 |
+
acc1,rob1,acc2,rob2 =[],[],[],[]
|
| 75 |
+
if args.net == "ResNet50":
|
| 76 |
+
arch = ResNetArchEncoder(image_size_list=[224 if args.dataset == 'imagenet' else 32],depth_list=[0,1,2],expand_list=[0.2,0.25,0.35],width_mult_list=[0.65,0.8,1.0])
|
| 77 |
+
else:
|
| 78 |
+
arch = MobileNetArchEncoder (image_size_list=[224 if args.dataset == 'imagenet' else 32],depth_list=[2,3,4],expand_list=[3,4,6],ks_list=[3,5,7])
|
| 79 |
+
print(arch)
|
| 80 |
+
acc_data = AccuracyRobustnessDataset("./acc_rob_data_{}_{}_{}".format(args.dataset,args.net,args.train_criterion))
|
| 81 |
+
train_loader, valid_loader, base_acc ,base_rob = acc_data.build_acc_data_loader(arch)
|
| 82 |
+
for inputs, targets_acc, targets_rob in train_loader:
|
| 83 |
+
for i in range(len(targets_acc)):
|
| 84 |
+
acc1.append(targets_acc[i].item() * 100)
|
| 85 |
+
rob1.append(targets_rob[i].item() * 100)
|
| 86 |
+
|
| 87 |
+
np.save("./results/acc_mbv3.npy",np.array(acc1))
|
| 88 |
+
np.save("./results/rob_mbv3.npy",np.array(rob1))
|
| 89 |
+
|
search_best.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import argparse
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm.auto import tqdm
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch
|
| 10 |
+
import random
|
| 11 |
+
from torch import nn
|
| 12 |
+
from torch.optim import *
|
| 13 |
+
from torch.optim.lr_scheduler import *
|
| 14 |
+
from torch.utils.data import DataLoader
|
| 15 |
+
from torchprofile import profile_macs
|
| 16 |
+
from torchvision.datasets import *
|
| 17 |
+
from torchvision.transforms import *
|
| 18 |
+
from proard.model_zoo import DYN_net
|
| 19 |
+
from proard.nas.accuracy_predictor import AccuracyPredictor,ResNetArchEncoder,RobustnessPredictor,MobileNetArchEncoder,Accuracy_Robustness_Predictor
|
| 20 |
+
from proard.nas.efficiency_predictor import ResNet50FLOPsModel,Mbv3FLOPsModel,ProxylessNASFLOPsModel
|
| 21 |
+
from proard.nas.search_algorithm import EvolutionFinder,DynIndividual_mbv,DynIndividual_res,DynRandomSampler,DynProblem_mbv,DynProblem_res,DynSampling,individual_to_arch_res,individual_to_arch_mbv
|
| 22 |
+
from utils.profile import trainable_param_num
|
| 23 |
+
from pymoo.core.individual import Individual
|
| 24 |
+
from pymoo.core.mutation import Mutation
|
| 25 |
+
from pymoo.core.population import Population
|
| 26 |
+
from pymoo.core.problem import Problem
|
| 27 |
+
from pymoo.core.sampling import Sampling
|
| 28 |
+
from pymoo.core.variable import Choice
|
| 29 |
+
from pymoo.operators.crossover.ux import UniformCrossover
|
| 30 |
+
from pymoo.operators.mutation.pm import PolynomialMutation
|
| 31 |
+
from pymoo.operators.mutation.rm import ChoiceRandomMutation
|
| 32 |
+
from pymoo.operators.selection.rnd import RandomSelection
|
| 33 |
+
from pymoo.operators.selection.tournament import TournamentSelection
|
| 34 |
+
from pymoo.algorithms.moo.nsga2 import NSGA2
|
| 35 |
+
from pymoo.algorithms.moo.sms import SMSEMOA
|
| 36 |
+
from pymoo.algorithms.moo.spea2 import SPEA2
|
| 37 |
+
from pymoo.optimize import minimize
|
| 38 |
+
from pymoo.termination import get_termination
|
| 39 |
+
from pymoo.termination.default import DefaultMultiObjectiveTermination
|
| 40 |
+
from pymoo.core.callback import Callback
|
| 41 |
+
from pymoo.util.display.column import Column
|
| 42 |
+
from pymoo.util.display.output import Output
|
| 43 |
+
from proard.classification.run_manager import ClassificationRunConfig, RunManager
|
| 44 |
+
parser = argparse.ArgumentParser()
|
| 45 |
+
parser.add_argument(
|
| 46 |
+
"-p", "--path", help="The path of cifar10", type=str, default="/dataset/cifar10"
|
| 47 |
+
)
|
| 48 |
+
parser.add_argument("-g", "--gpu", help="The gpu(s) to use", type=str, default="all")
|
| 49 |
+
parser.add_argument(
|
| 50 |
+
"-b",
|
| 51 |
+
"--batch-size",
|
| 52 |
+
help="The batch on every device for validation",
|
| 53 |
+
type=int,
|
| 54 |
+
default=100,
|
| 55 |
+
)
|
| 56 |
+
parser.add_argument("-j", "--workers", help="Number of workers", type=int, default=20)
|
| 57 |
+
parser.add_argument(
|
| 58 |
+
"-n",
|
| 59 |
+
"--net",
|
| 60 |
+
metavar="DYNNET",
|
| 61 |
+
default="ResNet50",
|
| 62 |
+
choices=[
|
| 63 |
+
"ResNet50",
|
| 64 |
+
"MBV3",
|
| 65 |
+
"ProxylessNASNet",
|
| 66 |
+
],
|
| 67 |
+
help="dynamic networks",
|
| 68 |
+
)
|
| 69 |
+
parser.add_argument(
|
| 70 |
+
"--dataset", type=str, default="cifar10" ,choices=["cifar10", "cifar100", "imagenet"]
|
| 71 |
+
)
|
| 72 |
+
parser.add_argument(
|
| 73 |
+
"--attack", type=str, default="linf-pgd" ,choices=['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd','squar_attack','autoattack','apgd_ce']
|
| 74 |
+
)
|
| 75 |
+
parser.add_argument("--train_criterion", type=str, default="trades",choices=["trades","sat","mart","hat"])
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--robust_mode", type=bool, default=True
|
| 78 |
+
)
|
| 79 |
+
args = parser.parse_args()
|
| 80 |
+
if args.gpu == "all":
|
| 81 |
+
device_list = range(torch.cuda.device_count())
|
| 82 |
+
args.gpu = ",".join(str(_) for _ in device_list)
|
| 83 |
+
else:
|
| 84 |
+
device_list = [int(_) for _ in args.gpu.split(",")]
|
| 85 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
|
| 86 |
+
args.batch_size = args.batch_size * max(len(device_list), 1)
|
| 87 |
+
run_config = ClassificationRunConfig(attack_type=args.attack, dataset= args.dataset, test_batch_size=args.batch_size, n_worker=args.workers,robust_mode=args.robust_mode)
|
| 88 |
+
dyn_network = DYN_net(args.net,args.robust_mode,args.dataset,args.train_criterion, pretrained=True,run_config=run_config)
|
| 89 |
+
if args.net == "ResNet50":
|
| 90 |
+
efficiency_predictor = ResNet50FLOPsModel(dyn_network)
|
| 91 |
+
arch = ResNetArchEncoder(image_size_list=[32],depth_list=[0,1,2],expand_list=[0.2,0.25,0.35],width_mult_list=[0.65,0.8,1.0])
|
| 92 |
+
accuracy_robustness_predictor = Accuracy_Robustness_Predictor(arch)
|
| 93 |
+
accuracy_robustness_predictor.load_state_dict(torch.load("./acc_rob_data_{}_{}_{}/src/model_acc_rob.pth".format(args.dataset,args.net,args.train_criterion)))
|
| 94 |
+
elif args.net == "MBV3":
|
| 95 |
+
efficiency_predictor = Mbv3FLOPsModel(dyn_network)
|
| 96 |
+
arch = MobileNetArchEncoder(image_size_list=[32],depth_list=[2,3,4],expand_list=[3,4,6],ks_list=[3,5,7])
|
| 97 |
+
accuracy_robustness_predictor = Accuracy_Robustness_Predictor(arch)
|
| 98 |
+
accuracy_robustness_predictor.load_state_dict(torch.load("./acc_rob_data_{}_{}_{}/src/model_acc_rob.pth".format(args.dataset,args.net,args.train_criterion)))
|
| 99 |
+
elif args.net == "ProxylessNASNet":
|
| 100 |
+
efficiency_predictor = ProxylessNASFLOPsModel(dyn_network)
|
| 101 |
+
arch = MobileNetArchEncoder(image_size_list=[32],depth_list=[2,3,4],expand_list=[3,4,6],width_mult_list=[3,5,7])
|
| 102 |
+
accuracy_robustness_predictor = Accuracy_Robustness_Predictor(arch)
|
| 103 |
+
accuracy_robustness_predictor.load_state_dict(torch.load("./acc_rob_data_{}_{}_{}/src/model_acc_rob.pth".format(args.dataset,args.net,args.train_criterion)))
|
| 104 |
+
##### Test #################################################
|
| 105 |
+
dyn_sampler = DynRandomSampler(arch, efficiency_predictor)
|
| 106 |
+
# arch1, eff1 = dyn_sampler.random_sample()
|
| 107 |
+
# arch2, eff2 = dyn_sampler.random_sample()
|
| 108 |
+
# print(accuracy_predictor.predict_acc([arch1, arch2]))
|
| 109 |
+
# print(arch1,eff1)
|
| 110 |
+
##################################################
|
| 111 |
+
|
| 112 |
+
""" Hyperparameters
|
| 113 |
+
- P: size of the population in each generation (number of individuals)
|
| 114 |
+
- N: number of generations to run the algorithm
|
| 115 |
+
- mutate_prob: probability of gene mutation in the evolutionary search
|
| 116 |
+
"""
|
| 117 |
+
P = 100
|
| 118 |
+
N = 100
|
| 119 |
+
mutation_prob = 0.5
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# variables options
|
| 123 |
+
if args.net == 'ResNet50':
|
| 124 |
+
search_space = {
|
| 125 |
+
'e': [0.2, 0.25, 0.35],
|
| 126 |
+
'd': [0, 1, 2],
|
| 127 |
+
'w': [0 ,1 ,2],
|
| 128 |
+
'image_size': [32]
|
| 129 |
+
}
|
| 130 |
+
else:
|
| 131 |
+
search_space = {
|
| 132 |
+
'ks': [3, 5, 7],
|
| 133 |
+
'e': [3, 4, 6],
|
| 134 |
+
'd': [2, 3, 4],
|
| 135 |
+
'image_size': [32]
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
#----------------------------
|
| 139 |
+
# units
|
| 140 |
+
num_blocks = arch.max_n_blocks
|
| 141 |
+
num_stages = arch.n_stage
|
| 142 |
+
Flops_constraints = 1600
|
| 143 |
+
if args.net == "ResNet50":
|
| 144 |
+
problem = DynProblem_res(efficiency_predictor, accuracy_robustness_predictor, num_blocks, num_stages, search_space,Flops_constraints)
|
| 145 |
+
else:
|
| 146 |
+
problem = DynProblem_mbv(efficiency_predictor, accuracy_robustness_predictor, num_blocks, num_stages, search_space,Flops_constraints)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
mutation_rc = ChoiceRandomMutation(prob=1.0, prob_var=0.1)
|
| 153 |
+
crossover_ux = UniformCrossover(prob=1.0)
|
| 154 |
+
# selection_tournament = TournamentSelection(
|
| 155 |
+
# func_comp=accuracy_predictor.predict_acc,
|
| 156 |
+
# pressure=2
|
| 157 |
+
# )
|
| 158 |
+
termination_default = DefaultMultiObjectiveTermination(
|
| 159 |
+
xtol=1e-8, cvtol=1e-6, ftol=0.0025, period=30, n_max_gen=1000, n_max_evals=100000
|
| 160 |
+
)
|
| 161 |
+
termination_gen = get_termination("n_gen", N)
|
| 162 |
+
np.random.seed(42)
|
| 163 |
+
random.seed(42)
|
| 164 |
+
if args.net=="ResNet50":
|
| 165 |
+
init_pop = Population(individuals=[DynIndividual_res(dyn_sampler.random_sample(), accuracy_robustness_predictor) for _ in range(P)])
|
| 166 |
+
else:
|
| 167 |
+
init_pop = Population(individuals=[DynIndividual_mbv(dyn_sampler.random_sample(), accuracy_robustness_predictor) for _ in range(P)])
|
| 168 |
+
|
| 169 |
+
algorithm = NSGA2(
|
| 170 |
+
pop_size=P,
|
| 171 |
+
sampling=DynSampling(),
|
| 172 |
+
# selection=selection_tournament,
|
| 173 |
+
crossover=crossover_ux,
|
| 174 |
+
mutation=mutation_rc,
|
| 175 |
+
# mutation=mutation_pm,
|
| 176 |
+
# survival=RankAndCrowdingSurvival(),
|
| 177 |
+
# output=MultiObjectiveOutput(),
|
| 178 |
+
# **kwargs
|
| 179 |
+
)
|
| 180 |
+
res_nsga2 = minimize(
|
| 181 |
+
problem,
|
| 182 |
+
algorithm,
|
| 183 |
+
termination=termination_gen,
|
| 184 |
+
seed=1,
|
| 185 |
+
#verbose=True,
|
| 186 |
+
verbose=False,
|
| 187 |
+
save_history=True,
|
| 188 |
+
)
|
| 189 |
+
# print(100-res_nsga2.history[99].pop.get('F')[:,0],100-res_nsga2.history[99].pop.get('F')[:,1])
|
| 190 |
+
# a = individual_to_arch_res(res_nsga2.pop.get('X'),num_blocks)[0]
|
| 191 |
+
# # print(a)
|
| 192 |
+
# # a['d'][3] = int(a['d'][3])
|
| 193 |
+
# a['d'][4] = int(a['d'][4])
|
| 194 |
+
# dyn_network.set_active_subnet(**a)
|
| 195 |
+
# subnet = dyn_network.get_active_subnet(preserve_weight=True)
|
| 196 |
+
# run_manager = RunManager(".tmp/eval_subnet", subnet, run_config, init=False)
|
| 197 |
+
# run_config.data_provider.assign_active_img_size(32)
|
| 198 |
+
# run_manager.reset_running_statistics(net=subnet)
|
| 199 |
+
|
| 200 |
+
# print("Test random subnet:")
|
| 201 |
+
# # print(subnet.module_str)
|
| 202 |
+
|
| 203 |
+
# loss, (top1, top5,robust1,robust5) = run_manager.validate(net=subnet,is_test=True)
|
| 204 |
+
# print("Results: loss=%.5f,\t top1=%.1f,\t top5=%.1f,\t robust1=%.1f,\t robust5=%.1f" % (loss, top1, top5,robust1,robust5))
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
np.savetxt("./results/acc_gen0.csv", 100-res_nsga2.history[0].pop.get('F')[:,0], delimiter=",")
|
| 208 |
+
|
| 209 |
+
np.savetxt("./results/acc_gen99.csv", 100-res_nsga2.history[99].pop.get('F')[:,0], delimiter=",")
|
| 210 |
+
np.savetxt("./results/rob_gen0.csv", 100-res_nsga2.history[0].pop.get('F')[:,1], delimiter=",")
|
| 211 |
+
|
| 212 |
+
np.savetxt("./results/rob_gen99.csv", 100-res_nsga2.history[99].pop.get('F')[:,1], delimiter=",")
|
| 213 |
+
np.savetxt("./results/flops_gen99.csv", res_nsga2.history[99].pop.get('G'), delimiter=",")
|
| 214 |
+
|
| 215 |
+
# np.savetxt("./results/robs.csv", np.array(robs), delimiter=",")
|
| 216 |
+
|
| 217 |
+
from matplotlib import pyplot as plt
|
| 218 |
+
from matplotlib.ticker import FormatStrFormatter
|
| 219 |
+
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
|
| 220 |
+
# NSGA-II population progression
|
| 221 |
+
x_min, x_max, y_min, y_max = 80, 93, 47, 56
|
| 222 |
+
ax_limits = [x_min, x_max, y_min, y_max]
|
| 223 |
+
#-------------------------------------------------
|
| 224 |
+
# plot
|
| 225 |
+
fig, ax = plt.subplots(dpi=600)
|
| 226 |
+
gen0 = 0
|
| 227 |
+
gen1 = 99
|
| 228 |
+
print(100-res_nsga2.history[gen1].pop.get('F')[:,0], 100 - res_nsga2.history[gen1].pop.get('F')[:,1])
|
| 229 |
+
# gen2 = 99
|
| 230 |
+
# print(res_nsga2.history[gen0].pop.get('F')[:,0],res_nsga2.history[gen0].pop.get('F')[:,1] )
|
| 231 |
+
ax.plot(100-res_nsga2.history[gen0].pop.get('F')[:,0], 100 - res_nsga2.history[gen0].pop.get('F')[:,1] , 'o', label=f'Population at generation #{gen0+1}', color='red', alpha=0.5)
|
| 232 |
+
ax.plot(100-res_nsga2.history[gen1].pop.get('F')[:,0], 100 - res_nsga2.history[gen1].pop.get('F')[:,1] , 'o', label=f'Population at generation #{gen1+1}', color='green', alpha=0.5)
|
| 233 |
+
# ax.plot(res_nsga2.history[gen2].pop.get('F')[:,0], 100 - res_nsga2.history[gen2].pop.get('F')[:,1], 'o', label=f'Population at generation #{gen2+1}', color='orange', alpha=0.5)
|
| 234 |
+
# ax.plot(res_nsga2.history[gen3].pop.get('F')[:,0], 100 - res_nsga2.history[gen3].pop.get('F')[:,1], 'o', label=f'Population at generation #{gen3+1}', color='blue', alpha=0.5)
|
| 235 |
+
#-------------------------------------------------
|
| 236 |
+
# text
|
| 237 |
+
ax.grid(True, linestyle=':')
|
| 238 |
+
ax.set_xlabel('Accuracy (%)')
|
| 239 |
+
ax.set_ylabel('Robustness (%)')
|
| 240 |
+
ax.set_title('NSGA-II solutions progression For Fixed number of FLOPs'),
|
| 241 |
+
ax.legend()
|
| 242 |
+
#-------------------------------------------------
|
| 243 |
+
# x-axis
|
| 244 |
+
ax.xaxis.set_major_locator(MultipleLocator(1))
|
| 245 |
+
ax.xaxis.set_minor_locator(MultipleLocator(1))
|
| 246 |
+
# y-axis
|
| 247 |
+
ax.yaxis.set_major_locator(MultipleLocator(1))
|
| 248 |
+
ax.yaxis.set_minor_locator(MultipleLocator(1))
|
| 249 |
+
# ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
|
| 250 |
+
ax.set(xlim=(ax_limits[0], ax_limits[1]), ylim=(ax_limits[2], ax_limits[3]))
|
| 251 |
+
#-------------------------------------------------
|
| 252 |
+
plt.savefig('nsga2_pop_progression_debug.png')
|
| 253 |
+
fig.set_dpi(100)
|
| 254 |
+
# plt.close(fig)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# plt.show()
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# finder = EvolutionFinder(efficiency_predictor,accuracy_predictor,Robustness_predictor)
|
| 262 |
+
# valid_constraint_range = 800
|
| 263 |
+
# best_valids, best_info = finder.run_evolution_search(constraint=valid_constraint_range,verbose=True)
|
| 264 |
+
# print(efficiency_predictor.get_efficiency(best_info[2]))
|
| 265 |
+
# dyn_network.set_active_subnet(best_info[2]['d'],best_info[2]['e'],best_info[2]['w'])
|
| 266 |
+
# subnet = dyn_network.get_active_subnet(preserve_weight=True)
|
| 267 |
+
# run_config = CifarRunConfig_robust(test_batch_size=args.batch_size, n_worker=args.workers)
|
| 268 |
+
# run_manager = RunManager_robust(".tmp/eval_subnet", subnet, run_config, init=False)
|
| 269 |
+
# run_config.data_provider.assign_active_img_size(32)
|
| 270 |
+
# run_manager.reset_running_statistics(net=subnet)
|
| 271 |
+
# loss, (top1, top5,robust1,robust5) = run_manager.validate(net=subnet)
|
| 272 |
+
# print("Results: loss=%.5f,\t top1=%.1f,\t top5=%.1f,\t robust1=%.1f,\t robust5=%.1f" % (loss, top1, top5,robust1,robust5))
|
| 273 |
+
# print("number of parameter={}M".format(trainable_param_num(subnet)))
|
train_ofa_net.py
ADDED
|
@@ -0,0 +1,558 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import argparse
import numpy as np
import os
import random

# using for distributed training
import horovod.torch as hvd
import torch


from proard.classification.elastic_nn.modules.dynamic_op import (
    DynamicSeparableConv2d,
)
from proard.classification.elastic_nn.networks import (
    DYNMobileNetV3,
    DYNProxylessNASNets,
    DYNResNets,
    DYNProxylessNASNets_Cifar,
    DYNMobileNetV3_Cifar,
    DYNResNets_Cifar,
)
from proard.classification.run_manager import DistributedClassificationRunConfig
from proard.classification.run_manager.distributed_run_manager import (
    DistributedRunManager,
)
from proard.utils import download_url, MyRandomResizedCrop
from proard.classification.elastic_nn.training.progressive_shrinking import load_models


def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``type=bool`` is a classic argparse pitfall: ``bool("False")`` is True,
    so any non-empty string — including "False" or "0" — would enable the
    flag. This converter accepts the usual spellings of both values and
    keeps the previous defaults intact.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if value.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


parser = argparse.ArgumentParser()
# Progressive-shrinking stage to run.
parser.add_argument(
    "--task",
    type=str,
    default="expand",
    choices=[
        "kernel",  # for architecture except ResNet
        "depth",
        "expand",
        "width",  # only for ResNet
    ],
)
parser.add_argument("--phase", type=int, default=2, choices=[1, 2])
parser.add_argument("--resume", action="store_true")
parser.add_argument("--model_name", type=str, default="MBV2", choices=["ResNet50", "MBV3", "ProxylessNASNet", "MBV2"])
parser.add_argument("--dataset", type=str, default="cifar100", choices=["cifar10", "cifar100", "imagenet"])
# Adversarial-training options (PGD / TRADES-style hyper-parameters).
# NOTE: was ``type=bool`` — any non-empty string (even "False") parsed as True.
parser.add_argument("--robust_mode", type=_str2bool, default=True)
parser.add_argument("--epsilon", type=float, default=0.031)
parser.add_argument("--num_steps", type=int, default=10)
parser.add_argument("--step_size", type=float, default=0.0078)
parser.add_argument("--clip_min", type=int, default=0)
parser.add_argument("--clip_max", type=int, default=1)
parser.add_argument("--const_init", type=_str2bool, default=False)
parser.add_argument("--beta", type=float, default=6.0)
parser.add_argument("--distance", type=str, default="l_inf", choices=["l_inf", "l2"])
parser.add_argument("--train_criterion", type=str, default="trades", choices=["trades", "sat", "mart", "hat"])
parser.add_argument("--test_criterion", type=str, default="ce", choices=["ce"])
parser.add_argument("--kd_criterion", type=str, default="rslad", choices=["ard", "rslad", "adaad"])
parser.add_argument("--attack_type", type=str, default="linf-pgd", choices=['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd', 'squar_attack', 'autoattack', 'apgd_ce'])

args = parser.parse_args()
# ---------------------------------------------------------------------------
# Derive the experiment directory and the progressive-shrinking schedule for
# the selected (model, task, phase). All elastic dimensions are encoded as
# comma-separated strings here; they are parsed into Python lists in __main__.
# ---------------------------------------------------------------------------
_exp_root = (
    ("exp/robust/" if args.robust_mode else "exp/")
    + args.dataset + "/" + args.model_name + "/" + args.train_criterion
)

if args.model_name == "ResNet50":
    # ResNet has no elastic kernel size; it shrinks width instead.
    args.ks_list = "3"
    if args.task == "width":
        args.path = _exp_root + "/normal2width"
        args.dynamic_batch_size = 1
        args.n_epochs = 120
        args.base_lr = 3e-2
        args.warmup_epochs = 5
        args.warmup_lr = -1
        args.width_mult_list = "0.65,0.8,1.0"
        args.expand_list = "0.35"
        args.depth_list = "2"
    elif args.task == "depth":
        args.path = _exp_root + "/width2width_depth/phase%d" % args.phase
        args.dynamic_batch_size = 2
        if args.phase == 1:
            # Phase 1: short warm-start run with a small learning rate.
            args.n_epochs = 25
            args.base_lr = 2.5e-3
            args.warmup_epochs = 0
            args.warmup_lr = -1
            args.width_mult_list = "0.65,0.8,1.0"
            args.expand_list = "0.35"
            args.depth_list = "1,2"
        else:
            # Phase 2: full run over the complete depth range.
            args.n_epochs = 120
            args.base_lr = 7.5e-3
            args.warmup_epochs = 5
            args.warmup_lr = -1
            args.width_mult_list = "0.65,0.8,1.0"
            args.expand_list = "0.35"
            args.depth_list = "0,1,2"
    elif args.task == "expand":
        args.path = _exp_root + "/width_depth2width_depth_width/phase%d" % args.phase
        args.dynamic_batch_size = 4
        if args.phase == 1:
            args.n_epochs = 25
            args.base_lr = 2.5e-3
            args.warmup_epochs = 0
            args.warmup_lr = -1
            args.width_mult_list = "0.65,0.8,1.0"
            args.expand_list = "0.25,0.35"
            args.depth_list = "0,1,2"
        else:
            args.n_epochs = 120
            args.base_lr = 7.5e-3
            args.warmup_epochs = 5
            args.warmup_lr = -1
            args.width_mult_list = "0.65,0.8,1.0"
            args.expand_list = "0.2,0.25,0.35"
            args.depth_list = "0,1,2"
    else:
        raise NotImplementedError
else:
    # MobileNet-style search spaces keep a single width multiplier and shrink
    # kernel size, depth and expand ratio instead.
    args.width_mult_list = "1.0"
    if args.task == "kernel":
        args.path = _exp_root + "/normal2kernel"
        args.dynamic_batch_size = 1
        args.n_epochs = 120
        args.base_lr = 3e-2
        args.warmup_epochs = 5
        args.warmup_lr = -1
        args.ks_list = "3,5,7"
        args.expand_list = "6"
        args.depth_list = "4"
    elif args.task == "depth":
        args.path = _exp_root + "/kernel2kernel_depth/phase%d" % args.phase
        args.dynamic_batch_size = 2
        if args.phase == 1:
            args.n_epochs = 25
            args.base_lr = 2.5e-3
            args.warmup_epochs = 0
            args.warmup_lr = -1
            args.ks_list = "3,5,7"
            args.expand_list = "6"
            args.depth_list = "3,4"
        else:
            args.n_epochs = 120
            args.base_lr = 7.5e-3
            args.warmup_epochs = 5
            args.warmup_lr = -1
            args.ks_list = "3,5,7"
            args.expand_list = "6"
            args.depth_list = "2,3,4"
    elif args.task == "expand":
        args.path = _exp_root + "/kernel_depth2kernel_depth_width/phase%d" % args.phase
        args.dynamic_batch_size = 4
        if args.phase == 1:
            args.n_epochs = 25
            args.base_lr = 2.5e-3
            args.warmup_epochs = 0
            args.warmup_lr = -1
            args.ks_list = "3,5,7"
            args.expand_list = "4,6"
            args.depth_list = "2,3,4"
        else:
            args.n_epochs = 120
            args.base_lr = 7.5e-3
            args.warmup_epochs = 5
            args.warmup_lr = -1
            args.ks_list = "3,5,7"
            args.expand_list = "3,4,6"
            args.depth_list = "2,3,4"
    else:
        raise NotImplementedError

# ---------------------------------------------------------------------------
# Fixed (task-independent) training settings.
# ---------------------------------------------------------------------------
args.manual_seed = 0

args.lr_schedule_type = "cosine"

args.base_batch_size = 64
args.valid_size = 64

args.opt_type = "sgd"
args.momentum = 0.9
args.no_nesterov = False
args.weight_decay = 3e-5
args.label_smoothing = 0.1
args.no_decay_keys = "bn#bias"  # parameters excluded from weight decay
args.fp16_allreduce = False

args.model_init = "he_fout"
args.validation_frequency = 1
args.print_frequency = 10

args.n_worker = 8
args.resize_scale = 0.08
args.distort_color = "tf"
# ImageNet trains with elastic resolution; CIFAR is fixed at 32x32.
if args.dataset == "imagenet":
    args.image_size = "128,160,192,224"
else:
    args.image_size = "32"
args.continuous_size = True
args.not_sync_distributed_image_size = False

args.bn_momentum = 0.1
args.bn_eps = 1e-5
args.dropout = 0.1
args.base_stage_width = "google"


args.dy_conv_scaling_mode = 1
args.independent_distributed_sampling = False

args.kd_ratio = 1.0  # knowledge-distillation weight; 0 disables the teacher
args.kd_type = "ce"
if __name__ == "__main__":
    os.makedirs(args.path, exist_ok=True)

    # Horovod bootstrap: one process per GPU, pinned by local rank.
    hvd.init()
    torch.cuda.set_device(hvd.local_rank())

    # Checkpoint of the separately-trained teacher (largest sub-network).
    _teacher_root = "exp/robust/teacher/" if args.robust_mode else "exp/teacher/"
    args.teacher_path = (
        _teacher_root + args.dataset + "/" + args.model_name + "/"
        + args.train_criterion + "/checkpoint/model_best.pth.tar"
    )
    num_gpus = hvd.size()

    # Seed every RNG so all workers draw identical architecture samples.
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed_all(args.manual_seed)
    np.random.seed(args.manual_seed)
    random.seed(args.manual_seed)

    # Image size: collapse a single-entry list to a plain int.
    args.image_size = [int(s) for s in args.image_size.split(",")]
    if len(args.image_size) == 1:
        args.image_size = args.image_size[0]
    MyRandomResizedCrop.CONTINUOUS = args.continuous_size
    MyRandomResizedCrop.SYNC_DISTRIBUTED = not args.not_sync_distributed_image_size

    # Build the run config from the flat args namespace.
    args.lr_schedule_param = None
    args.opt_param = {
        "momentum": args.momentum,
        "nesterov": not args.no_nesterov,
    }
    args.init_lr = args.base_lr * num_gpus  # linearly rescale the learning rate
    if args.warmup_lr < 0:
        args.warmup_lr = args.base_lr
    args.train_batch_size = args.base_batch_size
    args.test_batch_size = args.base_batch_size * 4
    run_config = DistributedClassificationRunConfig(
        **args.__dict__, num_replicas=num_gpus, rank=hvd.rank()
    )

    # Only the root rank prints the resolved configuration.
    if hvd.rank() == 0:
        print("Run config:")
        for k, v in run_config.config.items():
            print("\t%s: %s" % (k, v))

    if args.dy_conv_scaling_mode == -1:
        args.dy_conv_scaling_mode = None
    DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = args.dy_conv_scaling_mode

    # Parse the comma-separated elastic dimensions into lists.
    args.width_mult_list = [float(w) for w in args.width_mult_list.split(",")]
    args.ks_list = [int(k) for k in args.ks_list.split(",")]
    if args.model_name == "ResNet50":
        # ResNet expand values are fractional ratios; others are integers.
        args.expand_list = [float(e) for e in args.expand_list.split(",")]
    else:
        args.expand_list = [int(e) for e in args.expand_list.split(",")]
    args.depth_list = [int(d) for d in args.depth_list.split(",")]

    if len(args.width_mult_list) == 1:
        args.width_mult_list = args.width_mult_list[0]

    _is_cifar = args.dataset == "cifar10" or args.dataset == "cifar100"

    # ------------------------------------------------------------------
    # Student super-network.
    # ------------------------------------------------------------------
    if args.model_name == "ResNet50":
        _net_cls = DYNResNets_Cifar if _is_cifar else DYNResNets
        net = _net_cls(
            n_classes=run_config.data_provider.n_classes,
            bn_param=(args.bn_momentum, args.bn_eps),
            dropout_rate=args.dropout,
            depth_list=args.depth_list,
            expand_ratio_list=args.expand_list,
            width_mult_list=args.width_mult_list,
        )
    elif args.model_name == "MBV3":
        _net_cls = DYNMobileNetV3_Cifar if _is_cifar else DYNMobileNetV3
        net = _net_cls(
            n_classes=run_config.data_provider.n_classes,
            bn_param=(args.bn_momentum, args.bn_eps),
            dropout_rate=args.dropout,
            ks_list=args.ks_list,
            expand_ratio_list=args.expand_list,
            depth_list=args.depth_list,
            width_mult=args.width_mult_list,
        )
    elif args.model_name == "ProxylessNASNet":
        _net_cls = DYNProxylessNASNets_Cifar if _is_cifar else DYNProxylessNASNets
        net = _net_cls(
            n_classes=run_config.data_provider.n_classes,
            bn_param=(args.bn_momentum, args.bn_eps),
            dropout_rate=args.dropout,
            ks_list=args.ks_list,
            expand_ratio_list=args.expand_list,
            depth_list=args.depth_list,
            width_mult=args.width_mult_list,
        )
    elif args.model_name == "MBV2":
        # MBV2 reuses the ProxylessNAS implementation with a custom stage width.
        _net_cls = DYNProxylessNASNets_Cifar if _is_cifar else DYNProxylessNASNets
        net = _net_cls(
            n_classes=run_config.data_provider.n_classes,
            bn_param=(args.bn_momentum, args.bn_eps),
            dropout_rate=args.dropout,
            ks_list=args.ks_list,
            expand_ratio_list=args.expand_list,
            depth_list=args.depth_list,
            width_mult=args.width_mult_list,
            base_stage_width=args.base_stage_width,
        )
    else:
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Teacher model: the largest configuration of each search space.
    # ------------------------------------------------------------------
    if args.kd_ratio > 0:
        if args.model_name == "ResNet50":
            _teacher_cls = DYNResNets_Cifar if _is_cifar else DYNResNets
            args.teacher_model = _teacher_cls(
                n_classes=run_config.data_provider.n_classes,
                bn_param=(args.bn_momentum, args.bn_eps),
                dropout_rate=args.dropout,
                depth_list=[2],
                expand_ratio_list=[0.35],
                width_mult_list=[1.0],
            )
        else:
            if args.model_name == "MBV3":
                _teacher_cls = DYNMobileNetV3_Cifar if _is_cifar else DYNMobileNetV3
                _teacher_extra = {}
            elif args.model_name == "ProxylessNASNet":
                _teacher_cls = (
                    DYNProxylessNASNets_Cifar if _is_cifar else DYNProxylessNASNets
                )
                _teacher_extra = {}
            else:  # MBV2 (argparse choices rule out anything else)
                _teacher_cls = (
                    DYNProxylessNASNets_Cifar if _is_cifar else DYNProxylessNASNets
                )
                _teacher_extra = {"base_stage_width": args.base_stage_width}
            args.teacher_model = _teacher_cls(
                n_classes=run_config.data_provider.n_classes,
                bn_param=(args.bn_momentum, args.bn_eps),
                dropout_rate=0,
                width_mult=1.0,
                ks_list=[7],
                expand_ratio_list=[6],
                depth_list=[4],
                **_teacher_extra,
            )
        args.teacher_model.cuda()

    """ Distributed RunManager """
    # Horovod: (optional) gradient compression for the allreduce.
    compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
    distributed_run_manager = DistributedRunManager(
        args.path,
        net,
        run_config,
        compression,
        backward_steps=args.dynamic_batch_size,
        is_root=(hvd.rank() == 0),
    )
    distributed_run_manager.save_config()
    # Broadcast the initial weights so every rank starts identically.
    distributed_run_manager.broadcast()

    # Load teacher net weights.
    if args.kd_ratio > 0:
        load_models(
            distributed_run_manager, args.teacher_model, model_path=args.teacher_path
        )

    # Training entry points.
    from proard.classification.elastic_nn.training.progressive_shrinking import (
        validate,
        train,
    )

    _img_size_list = (
        {224 if args.dataset == "imagenet" else 32}
        if isinstance(args.image_size, int)
        else sorted({160, 224})
    )
    # Validate only the extreme (smallest/largest) sub-networks each epoch.
    if args.model_name == "ResNet50":
        validate_func_dict = {
            "image_size_list": _img_size_list,
            "width_mult_list": sorted(
                {min(args.width_mult_list), max(args.width_mult_list)}
            ),
            "expand_ratio_list": sorted({min(args.expand_list), max(args.expand_list)}),
            "depth_list": sorted({min(net.depth_list), max(net.depth_list)}),
        }
    else:
        validate_func_dict = {
            "image_size_list": _img_size_list,
            "width_mult_list": [1.0],
            "ks_list": sorted({min(args.ks_list), max(args.ks_list)}),
            "expand_ratio_list": sorted({min(args.expand_list), max(args.expand_list)}),
            "depth_list": sorted({min(net.depth_list), max(net.depth_list)}),
        }

    # Checkpoint locations of the previous progressive-shrinking stage.
    _exp_root = (
        ("exp/robust/" if args.robust_mode else "exp/")
        + args.dataset + "/" + args.model_name + "/" + args.train_criterion
    )
    _ckpt_suffix = "/checkpoint/model_best.pth.tar"
    _teacher_ckpt = (
        _teacher_root + args.dataset + "/" + args.model_name + "/"
        + args.train_criterion + _ckpt_suffix
    )

    if args.task == "width":
        from proard.classification.elastic_nn.training.progressive_shrinking import (
            train_elastic_width_mult,
        )

        if distributed_run_manager.start_epoch == 0:
            # Warm-start from the teacher checkpoint and log its accuracy.
            args.dyn_checkpoint_path = _teacher_ckpt
            load_models(
                distributed_run_manager,
                distributed_run_manager.net,
                args.dyn_checkpoint_path,
            )
            distributed_run_manager.write_log(
                "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
                % validate(distributed_run_manager, is_test=True, **validate_func_dict),
                "valid",
            )
        else:
            assert args.resume
        train_elastic_width_mult(train, distributed_run_manager, args, validate_func_dict)

    elif args.task == "kernel":
        validate_func_dict["ks_list"] = sorted(args.ks_list)
        if distributed_run_manager.start_epoch == 0:
            args.dyn_checkpoint_path = _teacher_ckpt
            load_models(
                distributed_run_manager,
                distributed_run_manager.net,
                args.dyn_checkpoint_path,
            )
            distributed_run_manager.write_log(
                "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
                % validate(distributed_run_manager, is_test=True, **validate_func_dict),
                "valid",
            )
        else:
            assert args.resume
        train(
            distributed_run_manager,
            args,
            lambda _run_manager, epoch, is_test: validate(
                _run_manager, epoch, is_test, **validate_func_dict
            ),
        )

    elif args.task == "depth":
        from proard.classification.elastic_nn.training.progressive_shrinking import (
            train_elastic_depth,
        )

        if args.model_name == "ResNet50":
            _prev_stage = (
                "/normal2width" if args.phase == 1 else "/width2width_depth/phase1"
            )
        else:
            _prev_stage = (
                "/normal2kernel" if args.phase == 1 else "/kernel2kernel_depth/phase1"
            )
        args.dyn_checkpoint_path = _exp_root + _prev_stage + _ckpt_suffix
        train_elastic_depth(train, distributed_run_manager, args, validate_func_dict)

    elif args.task == "expand":
        from proard.classification.elastic_nn.training.progressive_shrinking import (
            train_elastic_expand,
        )

        if args.model_name == "ResNet50":
            _prev_stage = (
                "/width2width_depth/phase2"
                if args.phase == 1
                else "/width_depth2width_depth_width/phase1"
            )
        else:
            _prev_stage = (
                "/kernel2kernel_depth/phase2"
                if args.phase == 1
                else "/kernel_depth2kernel_depth_width/phase1"
            )
        args.dyn_checkpoint_path = _exp_root + _prev_stage + _ckpt_suffix

        train_elastic_expand(train, distributed_run_manager, args, validate_func_dict)
    else:
        raise NotImplementedError
train_ofa_net_WPS.py
ADDED
|
@@ -0,0 +1,572 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Once for All: Train One Network and Specialize it for Efficient Deployment
|
| 2 |
+
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
|
| 3 |
+
# International Conference on Learning Representations (ICLR), 2020.
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import numpy as np
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
# using for distributed training
|
| 11 |
+
import horovod.torch as hvd
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
from proard.classification.elastic_nn.modules.dynamic_op import (
|
| 16 |
+
DynamicSeparableConv2d,
|
| 17 |
+
)
|
| 18 |
+
from proard.classification.elastic_nn.networks import DYNMobileNetV3,DYNProxylessNASNets,DYNResNets,DYNProxylessNASNets_Cifar,DYNMobileNetV3_Cifar,DYNResNets_Cifar
|
| 19 |
+
from proard.classification.run_manager import DistributedClassificationRunConfig
|
| 20 |
+
from proard.classification.run_manager.distributed_run_manager import (
|
| 21 |
+
DistributedRunManager
|
| 22 |
+
)
|
| 23 |
+
from proard.utils import download_url, MyRandomResizedCrop
|
| 24 |
+
from proard.classification.elastic_nn.training.progressive_shrinking import (
|
| 25 |
+
load_models,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
parser = argparse.ArgumentParser()
|
| 29 |
+
parser.add_argument(
|
| 30 |
+
"--task",
|
| 31 |
+
type=str,
|
| 32 |
+
default="expand",
|
| 33 |
+
choices=[
|
| 34 |
+
"kernel", # for architecture except ResNet
|
| 35 |
+
"depth",
|
| 36 |
+
"expand",
|
| 37 |
+
"width", # only for ResNet
|
| 38 |
+
],
|
| 39 |
+
)
|
| 40 |
+
parser.add_argument("--phase", type=int, default=2, choices=[1, 2])
|
| 41 |
+
parser.add_argument("--resume", action="store_true")
|
| 42 |
+
parser.add_argument("--model_name", type=str, default="MBV2", choices=["ResNet50", "MBV3", "ProxylessNASNet"])
|
| 43 |
+
parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "cifar100", "imagenet"])
|
| 44 |
+
parser.add_argument("--robust_mode", type=bool, default=True)
|
| 45 |
+
parser.add_argument("--epsilon", type=float, default=0.031)
|
| 46 |
+
parser.add_argument("--num_steps", type=int, default=10)
|
| 47 |
+
parser.add_argument("--step_size", type=float, default=0.0078)
|
| 48 |
+
parser.add_argument("--clip_min", type=int, default=0)
|
| 49 |
+
parser.add_argument("--clip_max", type=int, default=1)
|
| 50 |
+
parser.add_argument("--const_init", type=bool, default=False)
|
| 51 |
+
parser.add_argument("--beta", type=float, default=6.0)
|
| 52 |
+
parser.add_argument("--distance", type=str, default="l_inf",choices=["l_inf","l2"])
|
| 53 |
+
parser.add_argument("--train_criterion", type=str, default="trades",choices=["trades","sat","mart","hat"])
|
| 54 |
+
parser.add_argument("--test_criterion", type=str, default="ce",choices=["ce"])
|
| 55 |
+
parser.add_argument("--kd_criterion", type=str, default="rslad",choices=["ard","rslad","adaad"])
|
| 56 |
+
parser.add_argument("--attack_type", type=str, default="linf-pgd",choices=['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd','squar_attack','autoattack','apgd_ce'])
|
| 57 |
+
args = parser.parse_args()
|
| 58 |
+
if args.model_name == "ResNet50":
|
| 59 |
+
args.ks_list = "3"
|
| 60 |
+
if args.task == "width":
|
| 61 |
+
if args.robust_mode:
|
| 62 |
+
args.path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2width"
|
| 63 |
+
else:
|
| 64 |
+
args.path = "exp/WPS"+ args.dataset + '/' +args.model_name +'/' + args.train_criterion +"/normal2width"
|
| 65 |
+
args.dynamic_batch_size = 1
|
| 66 |
+
args.n_epochs = 120
|
| 67 |
+
args.base_lr = 3e-2
|
| 68 |
+
args.warmup_epochs = 5
|
| 69 |
+
args.warmup_lr = -1
|
| 70 |
+
args.width_mult_list = "0.65,0.8,1.0"
|
| 71 |
+
args.expand_list = "0.35"
|
| 72 |
+
args.depth_list = "2"
|
| 73 |
+
elif args.task == "depth":
|
| 74 |
+
if args.robust_mode:
|
| 75 |
+
args.path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase%d" % args.phase
|
| 76 |
+
else:
|
| 77 |
+
args.path = "exp/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase%d" % args.phase
|
| 78 |
+
args.dynamic_batch_size = 2
|
| 79 |
+
if args.phase == 1:
|
| 80 |
+
args.n_epochs = 25
|
| 81 |
+
args.base_lr = 2.5e-3
|
| 82 |
+
args.warmup_epochs = 0
|
| 83 |
+
args.warmup_lr = -1
|
| 84 |
+
args.width_mult_list = "0.65,0.8,1.0"
|
| 85 |
+
args.expand_list ="0.35"
|
| 86 |
+
args.depth_list = "1,2"
|
| 87 |
+
else:
|
| 88 |
+
args.n_epochs = 120
|
| 89 |
+
args.base_lr = 7.5e-3
|
| 90 |
+
args.warmup_epochs = 5
|
| 91 |
+
args.warmup_lr = -1
|
| 92 |
+
args.width_mult_list = "0.65,0.8,1.0"
|
| 93 |
+
args.expand_list = "0.35"
|
| 94 |
+
args.depth_list = "0,1,2"
|
| 95 |
+
elif args.task == "expand":
|
| 96 |
+
if args.robust_mode :
|
| 97 |
+
args.path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase%d" % args.phase
|
| 98 |
+
else:
|
| 99 |
+
args.path = "exp/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase%d" % args.phase
|
| 100 |
+
args.dynamic_batch_size = 4
|
| 101 |
+
if args.phase == 1:
|
| 102 |
+
args.n_epochs = 25
|
| 103 |
+
args.base_lr = 2.5e-3
|
| 104 |
+
args.warmup_epochs = 0
|
| 105 |
+
args.warmup_lr = -1
|
| 106 |
+
args.width_mult_list = "0.65,0.8,1.0"
|
| 107 |
+
args.expand_list = "0.25,0.35"
|
| 108 |
+
args.depth_list = "0,1,2"
|
| 109 |
+
else:
|
| 110 |
+
args.n_epochs = 120
|
| 111 |
+
args.base_lr = 7.5e-3
|
| 112 |
+
args.warmup_epochs = 5
|
| 113 |
+
args.warmup_lr = -1
|
| 114 |
+
args.width_mult_list = "0.65,0.8,1.0"
|
| 115 |
+
args.expand_list = "0.2,0.25,0.35"
|
| 116 |
+
args.depth_list = "0,1,2"
|
| 117 |
+
else:
|
| 118 |
+
raise NotImplementedError
|
| 119 |
+
else:
|
| 120 |
+
args.width_mult_list = "1.0"
|
| 121 |
+
if args.task == "kernel":
|
| 122 |
+
if args.robust_mode:
|
| 123 |
+
args.path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel"
|
| 124 |
+
else:
|
| 125 |
+
args.path = "exp/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel"
|
| 126 |
+
args.dynamic_batch_size = 1
|
| 127 |
+
args.n_epochs = 120
|
| 128 |
+
args.base_lr = 3e-2
|
| 129 |
+
args.warmup_epochs = 5
|
| 130 |
+
args.warmup_lr = -1
|
| 131 |
+
args.ks_list = "3,5,7"
|
| 132 |
+
args.expand_list = "6"
|
| 133 |
+
args.depth_list = "4"
|
| 134 |
+
elif args.task == "depth":
|
| 135 |
+
if args.robust_mode :
|
| 136 |
+
args.path = "exp/robust/WPS/"+args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase%d" % args.phase
|
| 137 |
+
else:
|
| 138 |
+
args.path = "exp/WPS/"+args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase%d" % args.phase
|
| 139 |
+
args.dynamic_batch_size = 2
|
| 140 |
+
if args.phase == 1:
|
| 141 |
+
args.n_epochs = 25
|
| 142 |
+
args.base_lr = 2.5e-3
|
| 143 |
+
args.warmup_epochs = 0
|
| 144 |
+
args.warmup_lr = -1
|
| 145 |
+
args.ks_list = "3,5,7"
|
| 146 |
+
args.expand_list = "6"
|
| 147 |
+
args.depth_list = "3,4"
|
| 148 |
+
else:
|
| 149 |
+
args.n_epochs = 120
|
| 150 |
+
args.base_lr = 7.5e-3
|
| 151 |
+
args.warmup_epochs = 5
|
| 152 |
+
args.warmup_lr = -1
|
| 153 |
+
args.ks_list = "3,5,7"
|
| 154 |
+
args.expand_list = "6"
|
| 155 |
+
args.depth_list = "2,3,4"
|
| 156 |
+
elif args.task == "expand":
|
| 157 |
+
if args.robust_mode:
|
| 158 |
+
args.path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel_depth2kernel_depth_width/phase%d" % args.phase
|
| 159 |
+
else:
|
| 160 |
+
args.path = "exp/WPS/"+ args.dataset + '/' + args.model_name + '/' + args.train_criterion + "/kernel_depth2kernel_depth_width/phase%d" % args.phase
|
| 161 |
+
args.dynamic_batch_size = 4
|
| 162 |
+
if args.phase == 1:
|
| 163 |
+
args.n_epochs = 25
|
| 164 |
+
args.base_lr = 2.5e-3
|
| 165 |
+
args.warmup_epochs = 0
|
| 166 |
+
args.warmup_lr = -1
|
| 167 |
+
args.ks_list = "3,5,7"
|
| 168 |
+
args.expand_list = "4,6"
|
| 169 |
+
args.depth_list = "2,3,4"
|
| 170 |
+
else:
|
| 171 |
+
args.n_epochs = 120
|
| 172 |
+
args.base_lr = 7.5e-3
|
| 173 |
+
args.warmup_epochs = 5
|
| 174 |
+
args.warmup_lr = -1
|
| 175 |
+
args.ks_list = "3,5,7"
|
| 176 |
+
args.expand_list = "3,4,6"
|
| 177 |
+
args.depth_list = "2,3,4"
|
| 178 |
+
else:
|
| 179 |
+
raise NotImplementedError
|
| 180 |
+
args.manual_seed = 0
|
| 181 |
+
|
| 182 |
+
args.lr_schedule_type = "cosine"
|
| 183 |
+
|
| 184 |
+
args.base_batch_size = 64
|
| 185 |
+
args.valid_size = 64
|
| 186 |
+
|
| 187 |
+
args.opt_type = "sgd"
|
| 188 |
+
args.momentum = 0.9
|
| 189 |
+
args.no_nesterov = False
|
| 190 |
+
args.weight_decay = 3e-5
|
| 191 |
+
args.label_smoothing = 0.1
|
| 192 |
+
args.no_decay_keys = "bn#bias"
|
| 193 |
+
args.fp16_allreduce = False
|
| 194 |
+
|
| 195 |
+
args.model_init = "he_fout"
|
| 196 |
+
args.validation_frequency = 1
|
| 197 |
+
args.print_frequency = 10
|
| 198 |
+
|
| 199 |
+
args.n_worker = 8
|
| 200 |
+
args.resize_scale = 0.08
|
| 201 |
+
args.distort_color = "tf"
|
| 202 |
+
if args.dataset == "imagenet":
|
| 203 |
+
args.image_size = "128,160,192,224"
|
| 204 |
+
else:
|
| 205 |
+
args.image_size = "32"
|
| 206 |
+
args.continuous_size = True
|
| 207 |
+
args.not_sync_distributed_image_size = False
|
| 208 |
+
|
| 209 |
+
args.bn_momentum = 0.1
|
| 210 |
+
args.bn_eps = 1e-5
|
| 211 |
+
args.dropout = 0.1
|
| 212 |
+
args.base_stage_width = "google"
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
args.dy_conv_scaling_mode = -1
|
| 216 |
+
args.independent_distributed_sampling = False
|
| 217 |
+
|
| 218 |
+
args.kd_ratio = 1.0
|
| 219 |
+
args.kd_type = "ce"
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
if __name__ == "__main__":
|
| 223 |
+
os.makedirs(args.path, exist_ok=True)
|
| 224 |
+
|
| 225 |
+
# Initialize Horovod
|
| 226 |
+
hvd.init()
|
| 227 |
+
# Pin GPU to be used to process local rank (one GPU per process)
|
| 228 |
+
torch.cuda.set_device(hvd.local_rank())
|
| 229 |
+
if args.robust_mode:
|
| 230 |
+
args.teacher_path = 'exp/robust/teacher/' + args.dataset + '/' + args.model_name + '/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
|
| 231 |
+
else:
|
| 232 |
+
args.teacher_path = 'exp/teacher/' + args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
|
| 233 |
+
num_gpus = hvd.size()
|
| 234 |
+
|
| 235 |
+
torch.manual_seed(args.manual_seed)
|
| 236 |
+
torch.cuda.manual_seed_all(args.manual_seed)
|
| 237 |
+
np.random.seed(args.manual_seed)
|
| 238 |
+
random.seed(args.manual_seed)
|
| 239 |
+
|
| 240 |
+
# image size
|
| 241 |
+
args.image_size = [int(img_size) for img_size in args.image_size.split(",")]
|
| 242 |
+
if len(args.image_size) == 1:
|
| 243 |
+
args.image_size = args.image_size[0]
|
| 244 |
+
MyRandomResizedCrop.CONTINUOUS = args.continuous_size
|
| 245 |
+
MyRandomResizedCrop.SYNC_DISTRIBUTED = not args.not_sync_distributed_image_size
|
| 246 |
+
|
| 247 |
+
# build run config from args
|
| 248 |
+
args.lr_schedule_param = None
|
| 249 |
+
args.opt_param = {
|
| 250 |
+
"momentum": args.momentum,
|
| 251 |
+
"nesterov": not args.no_nesterov,
|
| 252 |
+
}
|
| 253 |
+
args.init_lr = args.base_lr * num_gpus # linearly rescale the learning rate
|
| 254 |
+
if args.warmup_lr < 0:
|
| 255 |
+
args.warmup_lr = args.base_lr
|
| 256 |
+
args.train_batch_size = args.base_batch_size
|
| 257 |
+
args.test_batch_size = args.base_batch_size * 4
|
| 258 |
+
run_config = DistributedClassificationRunConfig(
|
| 259 |
+
**args.__dict__, num_replicas=num_gpus, rank=hvd.rank()
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
# print run config information
|
| 263 |
+
if hvd.rank() == 0:
|
| 264 |
+
print("Run config:")
|
| 265 |
+
for k, v in run_config.config.items():
|
| 266 |
+
print("\t%s: %s" % (k, v))
|
| 267 |
+
|
| 268 |
+
if args.dy_conv_scaling_mode == -1:
|
| 269 |
+
args.dy_conv_scaling_mode = None
|
| 270 |
+
DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = args.dy_conv_scaling_mode
|
| 271 |
+
|
| 272 |
+
# build net from args
|
| 273 |
+
args.width_mult_list = [
|
| 274 |
+
float(width_mult) for width_mult in args.width_mult_list.split(",")
|
| 275 |
+
]
|
| 276 |
+
args.ks_list = [int(ks) for ks in args.ks_list.split(",")]
|
| 277 |
+
if args.model_name == "ResNet50":
|
| 278 |
+
args.expand_list = [float(e) for e in args.expand_list.split(",")]
|
| 279 |
+
else:
|
| 280 |
+
args.expand_list = [int(e) for e in args.expand_list.split(",")]
|
| 281 |
+
args.depth_list = [int(d) for d in args.depth_list.split(",")]
|
| 282 |
+
|
| 283 |
+
args.width_mult_list = (
|
| 284 |
+
args.width_mult_list[0]
|
| 285 |
+
if len(args.width_mult_list) == 1
|
| 286 |
+
else args.width_mult_list
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
if args.model_name == "ResNet50":
|
| 290 |
+
if args.dataset == "cifar10" or args.dataset == "cifar100":
|
| 291 |
+
net = DYNResNets_Cifar( n_classes=run_config.data_provider.n_classes,
|
| 292 |
+
bn_param=(args.bn_momentum, args.bn_eps),
|
| 293 |
+
dropout_rate=args.dropout,
|
| 294 |
+
depth_list=args.depth_list,
|
| 295 |
+
expand_ratio_list=args.expand_list,
|
| 296 |
+
width_mult_list=args.width_mult_list,)
|
| 297 |
+
else:
|
| 298 |
+
net = DYNResNets( n_classes=run_config.data_provider.n_classes,
|
| 299 |
+
bn_param=(args.bn_momentum, args.bn_eps),
|
| 300 |
+
dropout_rate=args.dropout,
|
| 301 |
+
depth_list=args.depth_list,
|
| 302 |
+
expand_ratio_list=args.expand_list,
|
| 303 |
+
width_mult_list=args.width_mult_list,)
|
| 304 |
+
elif args.model_name == "MBV3":
|
| 305 |
+
if args.dataset == "cifar10" or args.dataset == "cifar100":
|
| 306 |
+
net = DYNMobileNetV3_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
|
| 307 |
+
dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
|
| 308 |
+
else:
|
| 309 |
+
net = DYNMobileNetV3(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
|
| 310 |
+
dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
|
| 311 |
+
elif args.model_name == "ProxylessNASNet":
|
| 312 |
+
if args.dataset == "cifar10" or args.dataset == "cifar100":
|
| 313 |
+
net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
|
| 314 |
+
dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
|
| 315 |
+
else:
|
| 316 |
+
net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
|
| 317 |
+
dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
|
| 318 |
+
elif args.model_name == "MBV2":
|
| 319 |
+
if args.dataset == "cifar10" or args.dataset == "cifar100":
|
| 320 |
+
net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
|
| 321 |
+
dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list,base_stage_width=args.base_stage_width)
|
| 322 |
+
else:
|
| 323 |
+
net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
|
| 324 |
+
dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list,base_stage_width=args.base_stage_width)
|
| 325 |
+
else:
|
| 326 |
+
raise NotImplementedError
|
| 327 |
+
# teacher model
|
| 328 |
+
if args.kd_ratio > 0:
|
| 329 |
+
|
| 330 |
+
if args.model_name =="ResNet50":
|
| 331 |
+
if args.dataset == "cifar10" or args.dataset == "cifar100":
|
| 332 |
+
args.teacher_model = DYNResNets_Cifar(
|
| 333 |
+
n_classes=run_config.data_provider.n_classes,
|
| 334 |
+
bn_param=(args.bn_momentum, args.bn_eps),
|
| 335 |
+
dropout_rate=args.dropout,
|
| 336 |
+
depth_list=[2],
|
| 337 |
+
expand_ratio_list=[0.35],
|
| 338 |
+
width_mult_list=[1.0],
|
| 339 |
+
)
|
| 340 |
+
else:
|
| 341 |
+
args.teacher_model = DYNResNets(
|
| 342 |
+
n_classes=run_config.data_provider.n_classes,
|
| 343 |
+
bn_param=(args.bn_momentum, args.bn_eps),
|
| 344 |
+
dropout_rate=args.dropout,
|
| 345 |
+
depth_list=[2],
|
| 346 |
+
expand_ratio_list=[0.35],
|
| 347 |
+
width_mult_list=[1.0],
|
| 348 |
+
)
|
| 349 |
+
elif args.model_name =="MBV3":
|
| 350 |
+
if args.dataset == "cifar10" or args.dataset == "cifar100":
|
| 351 |
+
args.teacher_model = DYNMobileNetV3_Cifar(
|
| 352 |
+
n_classes=run_config.data_provider.n_classes,
|
| 353 |
+
bn_param=(args.bn_momentum, args.bn_eps),
|
| 354 |
+
dropout_rate=0,
|
| 355 |
+
width_mult=1.0,
|
| 356 |
+
ks_list=[7],
|
| 357 |
+
expand_ratio_list=[6],
|
| 358 |
+
depth_list=[4]
|
| 359 |
+
)
|
| 360 |
+
else:
|
| 361 |
+
args.teacher_model = DYNMobileNetV3(
|
| 362 |
+
n_classes=run_config.data_provider.n_classes,
|
| 363 |
+
bn_param=(args.bn_momentum, args.bn_eps),
|
| 364 |
+
dropout_rate=0,
|
| 365 |
+
width_mult=1.0,
|
| 366 |
+
ks_list=[7],
|
| 367 |
+
expand_ratio_list=[6],
|
| 368 |
+
depth_list=[4]
|
| 369 |
+
)
|
| 370 |
+
elif args.model_name == "ProxylessNASNet":
|
| 371 |
+
if args.dataset == "cifar10" or args.dataset == "cifar100":
|
| 372 |
+
args.teacher_model = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,
|
| 373 |
+
bn_param=(args.bn_momentum, args.bn_eps),
|
| 374 |
+
dropout_rate=0,
|
| 375 |
+
width_mult=1.0,
|
| 376 |
+
ks_list=[7],
|
| 377 |
+
expand_ratio_list=[6],
|
| 378 |
+
depth_list=[4])
|
| 379 |
+
else:
|
| 380 |
+
args.teacher_model = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,
|
| 381 |
+
bn_param=(args.bn_momentum, args.bn_eps),
|
| 382 |
+
dropout_rate=0,
|
| 383 |
+
width_mult=1.0,
|
| 384 |
+
ks_list=[7],
|
| 385 |
+
expand_ratio_list=[6],
|
| 386 |
+
depth_list=[4])
|
| 387 |
+
elif args.model_name == "MBV2":
|
| 388 |
+
if args.dataset == "cifar10" or args.dataset == "cifar100":
|
| 389 |
+
args.teacher_model = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,
|
| 390 |
+
bn_param=(args.bn_momentum, args.bn_eps),
|
| 391 |
+
dropout_rate=0,
|
| 392 |
+
width_mult=1.0,
|
| 393 |
+
ks_list=[7],
|
| 394 |
+
expand_ratio_list=[6],
|
| 395 |
+
depth_list=[4],base_stage_width=args.base_stage_width)
|
| 396 |
+
else:
|
| 397 |
+
args.teacher_model = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,
|
| 398 |
+
bn_param=(args.bn_momentum, args.bn_eps),
|
| 399 |
+
dropout_rate=0,
|
| 400 |
+
width_mult=1.0,
|
| 401 |
+
ks_list=[7],
|
| 402 |
+
expand_ratio_list=[6],
|
| 403 |
+
depth_list=[4],base_stage_width=args.base_stage_width)
|
| 404 |
+
|
| 405 |
+
args.teacher_model.cuda()
|
| 406 |
+
|
| 407 |
+
""" Distributed RunManager """
|
| 408 |
+
# Horovod: (optional) compression algorithm.
|
| 409 |
+
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
|
| 410 |
+
distributed_run_manager = DistributedRunManager(
|
| 411 |
+
args.path,
|
| 412 |
+
net,
|
| 413 |
+
run_config,
|
| 414 |
+
compression,
|
| 415 |
+
backward_steps=args.dynamic_batch_size,
|
| 416 |
+
is_root=(hvd.rank() == 0),
|
| 417 |
+
)
|
| 418 |
+
distributed_run_manager.save_config()
|
| 419 |
+
# hvd broadcast
|
| 420 |
+
distributed_run_manager.broadcast()
|
| 421 |
+
|
| 422 |
+
# load teacher net weights
|
| 423 |
+
if args.kd_ratio > 0:
|
| 424 |
+
load_models(
|
| 425 |
+
distributed_run_manager, args.teacher_model, model_path=args.teacher_path
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
# training
|
| 429 |
+
from proard.classification.elastic_nn.training.progressive_shrinking import (
|
| 430 |
+
validate,
|
| 431 |
+
train,
|
| 432 |
+
)
|
| 433 |
+
if args.model_name =="ResNet50":
|
| 434 |
+
validate_func_dict = {
|
| 435 |
+
"image_size_list": {224 if args.dataset == "imagenet" else 32}
|
| 436 |
+
if isinstance(args.image_size, int)
|
| 437 |
+
else sorted({160, 224}),
|
| 438 |
+
"width_mult_list": sorted({min(args.width_mult_list), max(args.width_mult_list)}),
|
| 439 |
+
"expand_ratio_list": sorted({min(args.expand_list), max(args.expand_list)}),
|
| 440 |
+
"depth_list": sorted({min(net.depth_list), max(net.depth_list)}),
|
| 441 |
+
}
|
| 442 |
+
else:
|
| 443 |
+
validate_func_dict = {
|
| 444 |
+
"image_size_list": {224 if args.dataset == "imagenet" else 32}
|
| 445 |
+
if isinstance(args.image_size, int)
|
| 446 |
+
else sorted({160, 224}),
|
| 447 |
+
"width_mult_list": [1.0],
|
| 448 |
+
"ks_list": sorted({min(args.ks_list), max(args.ks_list)}),
|
| 449 |
+
"expand_ratio_list": sorted({min(args.expand_list), max(args.expand_list)}),
|
| 450 |
+
"depth_list": sorted({min(net.depth_list), max(net.depth_list)}),
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
if args.task == "width":
|
| 454 |
+
from proard.classification.elastic_nn.training.progressive_shrinking import (
|
| 455 |
+
train_elastic_width_mult,
|
| 456 |
+
)
|
| 457 |
+
if distributed_run_manager.start_epoch == 0:
|
| 458 |
+
if args.robust_mode:
|
| 459 |
+
args.dyn_checkpoint_path ='exp/robust/teacher/' +args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
|
| 460 |
+
else:
|
| 461 |
+
args.dyn_checkpoint_path ='exp/teacher/' +args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
|
| 462 |
+
load_models(
|
| 463 |
+
distributed_run_manager,
|
| 464 |
+
distributed_run_manager.net,
|
| 465 |
+
args.dyn_checkpoint_path,
|
| 466 |
+
)
|
| 467 |
+
distributed_run_manager.write_log(
|
| 468 |
+
"%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
|
| 469 |
+
% validate(distributed_run_manager, is_test=True, **validate_func_dict),
|
| 470 |
+
"valid",
|
| 471 |
+
)
|
| 472 |
+
else:
|
| 473 |
+
assert args.resume
|
| 474 |
+
train(distributed_run_manager,args,lambda _run_manager, epoch, is_test: validate(
|
| 475 |
+
_run_manager, epoch, is_test, **validate_func_dict
|
| 476 |
+
),)
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
elif args.task == "kernel":
|
| 481 |
+
validate_func_dict["ks_list"] = sorted(args.ks_list)
|
| 482 |
+
if distributed_run_manager.start_epoch == 0:
|
| 483 |
+
if args.robust_mode:
|
| 484 |
+
args.dyn_checkpoint_path ='exp/robust/teacher/' + args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
|
| 485 |
+
else:
|
| 486 |
+
args.dyn_checkpoint_path ='exp/teacher/' + args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
|
| 487 |
+
load_models(
|
| 488 |
+
distributed_run_manager,
|
| 489 |
+
distributed_run_manager.net,
|
| 490 |
+
args.dyn_checkpoint_path,
|
| 491 |
+
)
|
| 492 |
+
distributed_run_manager.write_log(
|
| 493 |
+
"%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
|
| 494 |
+
% validate(distributed_run_manager, is_test=True, **validate_func_dict),
|
| 495 |
+
"valid",
|
| 496 |
+
)
|
| 497 |
+
else:
|
| 498 |
+
assert args.resume
|
| 499 |
+
train(
|
| 500 |
+
distributed_run_manager,
|
| 501 |
+
args,
|
| 502 |
+
lambda _run_manager, epoch, is_test: validate(
|
| 503 |
+
_run_manager, epoch, is_test, **validate_func_dict
|
| 504 |
+
),
|
| 505 |
+
)
|
| 506 |
+
elif args.task == "depth":
|
| 507 |
+
from proard.classification.elastic_nn.training.progressive_shrinking import (
|
| 508 |
+
train_elastic_depth,
|
| 509 |
+
)
|
| 510 |
+
if args.robust_mode:
|
| 511 |
+
if args.model_name =="ResNet50":
|
| 512 |
+
if args.phase == 1:
|
| 513 |
+
args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/normal2width" +"/checkpoint/model_best.pth.tar"
|
| 514 |
+
else:
|
| 515 |
+
args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase1" + "/checkpoint/model_best.pth.tar"
|
| 516 |
+
else:
|
| 517 |
+
if args.phase == 1:
|
| 518 |
+
args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel" +"/checkpoint/model_best.pth.tar"
|
| 519 |
+
else:
|
| 520 |
+
args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase1" + "/checkpoint/model_best.pth.tar"
|
| 521 |
+
else :
|
| 522 |
+
if args.model_name =="ResNet50":
|
| 523 |
+
if args.phase == 1:
|
| 524 |
+
args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/normal2width" +"/checkpoint/model_best.pth.tar"
|
| 525 |
+
else:
|
| 526 |
+
args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase1" + "/checkpoint/model_best.pth.tar"
|
| 527 |
+
else:
|
| 528 |
+
if args.phase == 1:
|
| 529 |
+
args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel" +"/checkpoint/model_best.pth.tar"
|
| 530 |
+
else:
|
| 531 |
+
args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase1" + "/checkpoint/model_best.pth.tar"
|
| 532 |
+
train(
|
| 533 |
+
distributed_run_manager,
|
| 534 |
+
args,
|
| 535 |
+
lambda _run_manager, epoch, is_test: validate(
|
| 536 |
+
_run_manager, epoch, is_test, **validate_func_dict
|
| 537 |
+
),)
|
| 538 |
+
elif args.task == "expand":
|
| 539 |
+
from proard.classification.elastic_nn.training.progressive_shrinking import (
|
| 540 |
+
train_elastic_expand,
|
| 541 |
+
)
|
| 542 |
+
if args.robust_mode :
|
| 543 |
+
if args.model_name =="ResNet50":
|
| 544 |
+
if args.phase == 1:
|
| 545 |
+
args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width2width_depth/phase2" + "/checkpoint/model_best.pth.tar"
|
| 546 |
+
else:
|
| 547 |
+
args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
|
| 548 |
+
else:
|
| 549 |
+
if args.phase == 1:
|
| 550 |
+
args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase2" + "/checkpoint/model_best.pth.tar"
|
| 551 |
+
else:
|
| 552 |
+
args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel_depth2kernel_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
|
| 553 |
+
else:
|
| 554 |
+
if args.model_name =="ResNet50":
|
| 555 |
+
if args.phase == 1:
|
| 556 |
+
args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width2width_depth/phase2" + "/checkpoint/model_best.pth.tar"
|
| 557 |
+
else:
|
| 558 |
+
args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
|
| 559 |
+
else:
|
| 560 |
+
if args.phase == 1:
|
| 561 |
+
args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase2" + "/checkpoint/model_best.pth.tar"
|
| 562 |
+
else:
|
| 563 |
+
args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel_depth2kernel_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
|
| 564 |
+
|
| 565 |
+
train(
|
| 566 |
+
distributed_run_manager,
|
| 567 |
+
args,
|
| 568 |
+
lambda _run_manager, epoch, is_test: validate(
|
| 569 |
+
_run_manager, epoch, is_test, **validate_func_dict
|
| 570 |
+
),)
|
| 571 |
+
else:
|
| 572 |
+
raise NotImplementedError
|
train_teacher_net.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Once for All: Train One Network and Specialize it for Efficient Deployment
|
| 2 |
+
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
|
| 3 |
+
# International Conference on Learning Representations (ICLR), 2020.
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import numpy as np
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
# using for distributed training
|
| 10 |
+
import horovod.torch as hvd
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from proard.classification.elastic_nn.modules.dynamic_op import (
|
| 14 |
+
DynamicSeparableConv2d,
|
| 15 |
+
)
|
| 16 |
+
from proard.classification.elastic_nn.networks import DYNResNets,DYNMobileNetV3,DYNProxylessNASNets,DYNMobileNetV3_Cifar,DYNResNets_Cifar,DYNProxylessNASNets_Cifar
|
| 17 |
+
from proard.classification.run_manager import DistributedClassificationRunConfig
|
| 18 |
+
from proard.classification.networks import WideResNet
|
| 19 |
+
from proard.classification.run_manager import DistributedRunManager
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _str2bool(value):
    """Parse a boolean command-line flag.

    argparse's ``type=bool`` applies ``bool()`` to the raw string, so any
    non-empty value — including ``"False"`` — parses as True. This converter
    accepts the usual textual spellings and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


parser = argparse.ArgumentParser()
# Architecture / dataset selection.
parser.add_argument("--model_name", type=str, default="MBV2", choices=["ResNet50", "MBV3", "ProxylessNASNet", "WideResNet", "MBV2"])
parser.add_argument("--teacher_model_name", type=str, default="WideResNet", choices=["WideResNet"])
parser.add_argument("--dataset", type=str, default="cifar100", choices=["cifar10", "cifar100", "imagenet"])
# Adversarial-training hyper-parameters.
parser.add_argument("--robust_mode", type=_str2bool, default=True)  # was type=bool: "--robust_mode False" still enabled robust mode
parser.add_argument("--epsilon", type=float, default=0.031)  # perturbation budget
parser.add_argument("--num_steps", type=int, default=10)  # attack iterations
parser.add_argument("--step_size", type=float, default=0.0078)  # per-step attack size
parser.add_argument("--clip_min", type=int, default=0)
parser.add_argument("--clip_max", type=int, default=1)
parser.add_argument("--const_init", type=_str2bool, default=False)  # was type=bool (same truthiness bug)
parser.add_argument("--beta", type=float, default=6.0)  # robust-loss trade-off weight
parser.add_argument("--distance", type=str, default="l_inf", choices=["l_inf", "l2"])
parser.add_argument("--train_criterion", type=str, default="trades", choices=["trades", "sat", "mart", "hat"])
parser.add_argument("--test_criterion", type=str, default="ce", choices=["ce"])
parser.add_argument("--kd_criterion", type=str, default="rslad", choices=["ard", "rslad", "adaad"])
parser.add_argument("--attack_type", type=str, default="linf-pgd", choices=['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd', 'squar_attack', 'autoattack', 'apgd_ce'])


args = parser.parse_args()
# Output directory for checkpoints/logs, keyed by mode, dataset, and model.
if args.robust_mode:
    args.path = 'exp/robust/teacher/' + args.dataset + "/" + args.model_name + '/' + args.train_criterion
else:
    args.path = 'exp/teacher/' + args.dataset + "/" + args.model_name
# Fixed training schedule / optimizer settings (not exposed on the CLI).
args.n_epochs = 120
args.base_lr = 0.1
args.warmup_epochs = 5
args.warmup_lr = -1  # sentinel: replaced by base_lr in __main__
args.manual_seed = 0
args.lr_schedule_type = "cosine"
args.base_batch_size = 128
args.valid_size = None
args.opt_type = "sgd"
args.momentum = 0.9
args.no_nesterov = False
args.weight_decay = 2e-4
args.label_smoothing = 0.0
args.no_decay_keys = "bn#bias"  # parameter groups excluded from weight decay
args.fp16_allreduce = False
args.model_init = "he_fout"
args.validation_frequency = 1
args.print_frequency = 10
args.n_worker = 32
# Input resolution depends on the dataset (parsed to int(s) in __main__).
if args.dataset == "imagenet":
    args.image_size = "224"
else:
    args.image_size = "32"
args.continuous_size = True
args.not_sync_distributed_image_size = False
args.bn_momentum = 0.1
args.bn_eps = 1e-5
args.dropout = 0.0
args.base_stage_width = "google"
###### Parameters for MBV3, ProxylessNet, and MBV2
if args.model_name != "ResNet50":
    args.ks_list = '7'
    args.expand_list = '6'
    args.depth_list = '4'
    args.width_mult_list = "1.0"
else:
    ###### Parameters for ResNet50
    args.ks_list = "3"
    args.expand_list = "0.35"
    args.depth_list = "2"
    args.width_mult_list = "1.0"
########################################
args.dy_conv_scaling_mode = 1
args.independent_distributed_sampling = False
args.kd_ratio = 0.0  # teacher training: knowledge distillation disabled
args.kd_type = "ce"
args.dynamic_batch_size = 1
args.num_gpus = 4
| 94 |
+
if __name__ == "__main__":
    # Entry point: trains the teacher network (WideResNet) with Horovod
    # data parallelism. `args` is the module-level namespace assembled
    # above from CLI options plus hard-coded defaults; several fields are
    # mutated in place below before being expanded into the run config.
    os.makedirs(args.path, exist_ok=True)

    # Initialize Horovod (must precede any hvd.* query below).
    hvd.init()
    # Pin GPU to be used to process local rank (one GPU per process)
    torch.cuda.set_device(hvd.local_rank())

    num_gpus = hvd.size()  # world size; used to rescale the learning rate
    # Seed all RNG sources so every worker starts from the same state.
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed_all(args.manual_seed)
    np.random.seed(args.manual_seed)
    random.seed(args.manual_seed)

    # Parse the comma-separated image-size string; collapse to a scalar
    # when a single size is given.
    args.image_size = [int(img_size) for img_size in args.image_size.split(",")]
    if len(args.image_size) == 1:
        args.image_size = args.image_size[0]

    # Build run config from args.
    args.lr_schedule_param = None
    args.opt_param = {
        "momentum": args.momentum,
        "nesterov": not args.no_nesterov,
    }
    args.init_lr = args.base_lr * num_gpus  # linearly rescale the learning rate with world size
    if args.warmup_lr < 0:
        # -1 is a sentinel meaning "warm up from the base learning rate".
        args.warmup_lr = args.base_lr
    args.train_batch_size = args.base_batch_size
    args.test_batch_size = args.base_batch_size
    print(args.__dict__)
    run_config = DistributedClassificationRunConfig(
        **args.__dict__,num_replicas=num_gpus, rank=hvd.rank()
    )

    # Print run config information on the root worker only.
    if hvd.rank() == 0:
        print("Run config:")
        for k, v in run_config.config.items():
            print("\t%s: %s" % (k, v))

    # -1 is a sentinel for "no kernel transform" in the dynamic conv op.
    if args.dy_conv_scaling_mode == -1:
        args.dy_conv_scaling_mode = None
    DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = args.dy_conv_scaling_mode

    # Parse the elastic-dimension search-space strings into numeric lists.
    args.width_mult_list = [
        float(width_mult) for width_mult in args.width_mult_list.split(",")
    ]
    args.ks_list = [int(ks) for ks in args.ks_list.split(",")]
    args.expand_list = [float(e) for e in args.expand_list.split(",")]
    args.depth_list = [int(d) for d in args.depth_list.split(",")]

    # A single width multiplier is passed as a scalar rather than a list.
    args.width_mult_list = (
        args.width_mult_list[0]
        if len(args.width_mult_list) == 1
        else args.width_mult_list
    )
    # Build the dynamic (super-)network selected by --model_name; CIFAR
    # datasets use the *_Cifar variant of each architecture.
    if args.model_name == "ResNet50":
        if args.dataset == "cifar10" or args.dataset == "cifar100":
            # net = ResNet50_Cifar(n_classes=run_config.data_provider.n_classes)
            net = DYNResNets_Cifar( n_classes=run_config.data_provider.n_classes,
                                    bn_param=(args.bn_momentum, args.bn_eps),
                                    dropout_rate=args.dropout,
                                    depth_list=args.depth_list,
                                    expand_ratio_list=args.expand_list,
                                    width_mult_list=args.width_mult_list,)
        else:
            net = DYNResNets( n_classes=run_config.data_provider.n_classes,
                              bn_param=(args.bn_momentum, args.bn_eps),
                              dropout_rate=args.dropout,
                              depth_list=args.depth_list,
                              expand_ratio_list=args.expand_list,
                              width_mult_list=args.width_mult_list,)
    elif args.model_name == "MBV3":
        if args.dataset == "cifar10" or args.dataset == "cifar100":
            net = DYNMobileNetV3_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
                                       dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list)
        else:
            net = DYNMobileNetV3(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
                                 dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list)
    elif args.model_name == "ProxylessNASNet":
        if args.dataset == "cifar10" or args.dataset == "cifar100":
            net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
                                            dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list)
        else:
            net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
                                      dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list)

    elif args.model_name == "MBV2":
        # MBV2 reuses the ProxylessNAS dynamic network with an explicit
        # base_stage_width, unlike the ProxylessNASNet branch above.
        if args.dataset == "cifar10" or args.dataset == "cifar100":
            net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),base_stage_width=args.base_stage_width,
                                            dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list)
        else:
            net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),base_stage_width=args.base_stage_width,
                                      dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list)
    else:
        raise NotImplementedError
    # NOTE(review): `net` is unconditionally replaced here by the teacher
    # WideResNet (the only --teacher_model_name choice), so the dynamic
    # network built above is discarded in this teacher-training script —
    # confirm this is intentional.
    if args.teacher_model_name == "WideResNet":
        if args.dataset == "cifar10" or args.dataset == "cifar100":
            net = WideResNet(num_classes=run_config.data_provider.n_classes)
        else:
            # No ImageNet teacher variant is provided.
            raise NotImplementedError
    else:
        raise NotImplementedError
    args.teacher_model = None #'exp/teacher/' + args.dataset + "/" + "WideResNet"

    """ Distributed RunManager """
    # Horovod: (optional) gradient compression algorithm for allreduce.
    compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
    distributed_run_manager = DistributedRunManager(
        args.path,
        net,
        run_config,
        compression,
        backward_steps=args.dynamic_batch_size,
        is_root=(hvd.rank() == 0),
    )
    distributed_run_manager.save_config()
    # Broadcast initial state from the root so all workers start identical.
    distributed_run_manager.broadcast()


    distributed_run_manager.train(args)