smi08 committed on
Commit f6a2150 · verified · 1 Parent(s): d008243

Upload 10 files

README.md CHANGED
@@ -1,3 +1,185 @@
1
- ---
2
- license: unknown
3
- ---
1
+ <h1 align="center">
2
+ <img src="images/ProARD_logo.png" width="500"/>
3
+ <br/>
4
+ PROARD: PROGRESSIVE ADVERSARIAL ROBUSTNESS DISTILLATION: PROVIDE WIDE RANGE OF ROBUST STUDENTS
5
+ </br>
6
+ </h1>
7
+ <p align="center">
8
+ <a href="#background">Background</a> •
9
+ <a href="#usage">Usage</a> •
10
+ <a href="#code">Code</a> •
11
+ <a href="#citation">Citation</a> •
12
+ </p>
13
+
14
+ ## Background
15
+ Progressive Adversarial Robustness Distillation (ProARD) enables the efficient
16
+ one-time training of a dynamic network that supports a diverse range of accurate and robust student
17
+ networks without requiring retraining. ProARD builds a dynamic deep neural network from
18
+ dynamic layers that vary in width, depth, and expansion ratio at each design stage to
19
+ support a wide range of architectures.
20
+
21
+ <h1 align="center">
22
+ <img src="images/ProARD.png" width="1000"/>
23
+ </h1>
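+
+ In practice this means that, after the one-time training, any supported student architecture can be extracted from the dynamic network and evaluated without retraining. A minimal sketch of doing this programmatically (it mirrors `eval_ofa_net.py` in this repo; the chosen sub-network configuration, dataset, and availability of pretrained checkpoints are assumptions):
+
+ ```
+ from proard.classification.run_manager import ClassificationRunConfig, RunManager
+ from proard.model_zoo import DYN_net
+
+ # Load the trained dynamic ResNet50 (robust mode, TRADES-trained) for CIFAR-10.
+ run_config = ClassificationRunConfig(attack_type="linf-pgd", dataset="cifar10",
+                                      test_batch_size=16, n_worker=4, robust_mode=True)
+ dyn_network = DYN_net("ResNet50", True, "cifar10", "trades",
+                       pretrained=True, run_config=run_config, WPS=False, base=False)
+
+ # Pick one depth / expand-ratio / width configuration and extract that student.
+ dyn_network.set_active_subnet(d=2, e=0.35, w=1.0)
+ subnet = dyn_network.get_active_subnet(preserve_weight=True)
+
+ # Evaluate clean and robust accuracy of the extracted student.
+ run_manager = RunManager(".tmp/eval_subnet", subnet, run_config, init=False)
+ run_config.data_provider.assign_active_img_size(32)
+ run_manager.reset_running_statistics(net=subnet)
+ loss, (top1, top5, robust1, robust5) = run_manager.validate(net=subnet, is_test=True)
+ print(top1, robust1)
+ ```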
24
+
25
+
26
+
27
+ ## Usage
28
+ ```
29
+ git clone https://github.com/hamidmousavi0/ProARD.git
30
+ ```
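+
+ The scripts in this repository additionally rely on `horovod` (distributed training), `torchprofile` (FLOPs counting), `pymoo` (multi-objective search), `tqdm`, and `matplotlib`, all of which are imported by the provided code. A minimal environment sketch (exact versions are not pinned here and are an assumption):
+
+ ```
+ pip install torch torchvision horovod torchprofile pymoo tqdm matplotlib
+ ```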
31
+ ## Code Structure
32
+ ```
33
+ - attacks/ # Adversarial attack methods (PGD, AutoAttack, FGSM, DeepFool, etc.; [Reference](https://github.com/imrahulr/hat.git))
34
+ - proard/
35
+ - classification/
36
+ - data_providers/ # The dataset and dataloader definitions for CIFAR-10, CIFAR-100, and ImageNet.
37
+ - elastic_nn/
38
+ - modules/ # The definition of dynamic layers
39
+ - networks/ # The definition of dynamic networks
40
+ - training/ # Progressive training
41
+ - networks/ # The original networks
42
+ - run_manager/ # The configs and distributed training
43
+ - nas
44
+ - accuracy_predictor/ # The accuracy and robustness predictor
45
+ - efficiency_predictor/ # The efficiency predictor
46
+ - search_algorithm/ # The Multi-Objective Search Engine
47
+ - utils/ # Utility functions
48
+ - model_zoo.py # All the models for evaluation
49
+ - create_acc_rob_pred_dataset.py # Create dataset to train the accuracy-robustness predictor.
50
+ - create_acc_rob_pred.py # Build the accuracy-robustness predictor model.
51
+ - eval_ofa_net.py # Eval the sub-nets
52
+ - search_best.py # Search the best sub-net
53
+ - train_ofa_net_WPS.py # train the dynamic network without progressive training.
54
+ - train_ofa_net.py # Train the dynamic network with progressive training.
55
+ - train_teacher_net.py # Train the teacher network for robust knowledge distillation.
56
+
57
+ ```
58
+ ### Installing
59
+
60
+ **From Source**
61
+
62
+ Download this repository into your project folder.
63
+
64
+ ### Details of the usage
65
+
66
+ ## Evaluation
67
+
68
+ ```
69
+ python eval_ofa_net.py --path <dataset path> --net <dynamic net name: ResNet50, MBV3>
70
+ --dataset (cifar10, cifar100) --robust_mode (True, False)
71
+ --WPS (True, False) --attack ('fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd','squar_attack','autoattack','apgd_ce')
72
+ ```
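+
+ For example, a concrete invocation (the dataset path is a placeholder) could be:
+
+ ```
+ python eval_ofa_net.py --path ./dataset/cifar10 --net ResNet50 --dataset cifar10 \
+     --robust_mode True --WPS False --attack linf-pgd
+ ```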
73
+
74
+
75
+ ## Training
76
+
77
+ ### Step-0: Train Teacher Net
78
+
79
+ ```
80
+ horovodrun -np 4 python train_teacher_net.py --model_name ("ResNet50", "MBV3") --dataset (cifar10, cifar100)
81
+ --robust_mode (True, False) --epsilon 0.031 --num_steps 10
82
+ --step_size 0.0078 --distance 'l_inf' --train_criterion 'trades'
83
+ --attack_type 'linf-pgd'
84
+ ```
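+
+ For example, to train a robust ResNet50 teacher on CIFAR-10 with TRADES (a sketch; adjust `-np` to the number of available GPUs):
+
+ ```
+ horovodrun -np 4 python train_teacher_net.py --model_name ResNet50 --dataset cifar10 \
+     --robust_mode True --epsilon 0.031 --num_steps 10 --step_size 0.0078 \
+     --distance 'l_inf' --train_criterion 'trades' --attack_type 'linf-pgd'
+ ```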
85
+
86
+ ### Step-1: Dynamic Width/Kernel training
87
+
88
+
89
+ ```
90
+ horovodrun -np 4 python train_ofa_net.py --task ('width' for ResNet50, 'kernel' for MBV3) --model_name ("ResNet50", "MBV3") --dataset (cifar10, cifar100)
91
+ --robust_mode (True, False) --epsilon 0.031 --num_steps 10
92
+ --step_size 0.0078 --distance 'l_inf' --train_criterion 'trades'
93
+ --attack_type 'linf-pgd' --kd_criterion 'rslad' --phase 1
94
+ ```
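+
+ For example, for the dynamic ResNet50 the first elastic dimension is width (a sketch; `train_ofa_net.py` fixes the learning-rate and epoch schedule for each task internally):
+
+ ```
+ horovodrun -np 4 python train_ofa_net.py --task 'width' --model_name ResNet50 --dataset cifar10 \
+     --robust_mode True --epsilon 0.031 --num_steps 10 --step_size 0.0078 \
+     --distance 'l_inf' --train_criterion 'trades' --attack_type 'linf-pgd' \
+     --kd_criterion 'rslad' --phase 1
+ ```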
95
+
96
+ ### Step-2: Dynamic Width/Kernel and depth training
97
+
98
+ ##### Phase-1
99
+ ```
100
+ horovodrun -np 4 python train_ofa_net.py --task 'depth' --model_name ("ResNet50", "MBV3") --dataset (cifar10, cifar100)
101
+ --robust_mode (True, False) --epsilon 0.031 --num_steps 10
102
+ --step_size 0.0078 --distance 'l_inf' --train_criterion 'trades'
103
+ --attack_type 'linf-pgd' --kd_criterion 'rslad' --phase 1
104
+ ```
105
+ ##### Phase-2
106
+ ```
107
+ horovodrun -np 4 python train_ofa_net.py --task 'depth' --model_name ("ResNet50", "MBV3") --dataset (cifar10, cifar100)
108
+ --robust_mode (True, False) --epsilon 0.031 --num_steps 10
109
+ --step_size 0.0078 --distance 'l_inf' --train_criterion 'trades'
110
+ --attack_type 'linf-pgd' --kd_criterion 'rslad' --phase 2
111
+ ```
112
+
113
+ ### Step-3: Dynamic Width/Kernel, depth, and expand training
114
+
115
+
116
+ ##### Phase-1
117
+ ```
118
+ horovodrun -np 4 python train_ofa_net.py --task 'expand' --model_name ("ResNet50", "MBV3") --dataset (cifar10, cifar100)
119
+ --robust_mode (True, False) --epsilon 0.031 --num_steps 10
120
+ --step_size 0.0078 --distance 'l_inf' --train_criterion 'trades'
121
+ --attack_type 'linf-pgd' --kd_criterion 'rslad' --phase 1
122
+ ```
123
+ ##### Phase-2
124
+ ```
125
+ horovodrun -np 4 python train_ofa_net.py --task 'expand' --model_name ("ResNet50", "MBV3") --dataset (cifar10, cifar100)
126
+ --robust_mode (True, False) --epsilon 0.031 --num_steps 10
127
+ --step_size 0.0078 --distance 'l_inf' --train_criterion 'trades'
128
+ --attack_type 'linf-pgd' --kd_criterion 'rslad' --phase 2
129
+ ```
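+
+ Putting the steps together, a full progressive run for the dynamic ResNet50 on CIFAR-10 follows the order below (a sketch assembled from the commands above; flags not shown fall back to each script's defaults, which for `train_ofa_net.py` match the values listed above and for `train_teacher_net.py` are an assumption):
+
+ ```
+ # Step-0: robust teacher
+ horovodrun -np 4 python train_teacher_net.py --model_name ResNet50 --dataset cifar10 --robust_mode True --train_criterion 'trades' --attack_type 'linf-pgd'
+ # Step-1: elastic width
+ horovodrun -np 4 python train_ofa_net.py --task 'width' --model_name ResNet50 --dataset cifar10 --robust_mode True --train_criterion 'trades' --kd_criterion 'rslad' --phase 1
+ # Step-2: elastic width + depth (two phases)
+ horovodrun -np 4 python train_ofa_net.py --task 'depth' --model_name ResNet50 --dataset cifar10 --robust_mode True --train_criterion 'trades' --kd_criterion 'rslad' --phase 1
+ horovodrun -np 4 python train_ofa_net.py --task 'depth' --model_name ResNet50 --dataset cifar10 --robust_mode True --train_criterion 'trades' --kd_criterion 'rslad' --phase 2
+ # Step-3: elastic width + depth + expand (two phases)
+ horovodrun -np 4 python train_ofa_net.py --task 'expand' --model_name ResNet50 --dataset cifar10 --robust_mode True --train_criterion 'trades' --kd_criterion 'rslad' --phase 1
+ horovodrun -np 4 python train_ofa_net.py --task 'expand' --model_name ResNet50 --dataset cifar10 --robust_mode True --train_criterion 'trades' --kd_criterion 'rslad' --phase 2
+ ```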
130
+
131
+
132
+
133
+
134
+
135
+ <!--
136
+
137
+ * **ProAct** (the proposed algorithm) ([paper](https://arxiv.org/abs/2406.06313) and ([code](https://github.com/hamidmousavi0/reliable-relu-toolbox/tree/master/rrelu/search_bound/proact.py)).
138
+ * **FitAct** ([paper](https://arxiv.org/pdf/2112.13544) and [code](https://github.com/hamidmousavi0/reliable-relu-toolbox/tree/master/rrelu/search_bound/fitact.py)).
139
+ * **FtClipAct** ([paper](https://arxiv.org/pdf/1912.00941) and [code](https://github.com/hamidmousavi0/reliable-relu-toolbox/tree/master/rrelu/search_bound/ftclip.py)).
140
+ * **Ranger** ([paper](https://arxiv.org/pdf/2003.13874) and [code](https://github.com/hamidmousavi0/reliable-relu-toolbox/tree/master/rrelu/search_bound/ranger.py)).
141
+ -->
142
+
143
+
144
+
145
+
146
+ <!-- Use the following notebook to learn the main steps of the tool.
147
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/hamidmousavi0/reliable-relu-toolbox/blob/master/RReLU.ipynb)-->
148
+
149
+
150
+ ## To-do list
151
+ - [ ] Add object detection task
152
+ - [ ] Add Transformer architectures
153
+
154
+ <!--
155
+ ### Run search in the command line
156
+
157
+ When you download this repository into your project folder.
158
+ ```
159
+ torchrun --nproc_per_node=2 search.py --dataset cifar10 --data_path "./dataset/cifar10" --batch_size 128 --model "resnet20" --n_worker 32 \
160
+ --name_relu_bound "zero" --name_serach_bound "ranger" --bounds_type "layer" --bitflip "fixed" --image_size 32 --pretrained_model
161
+ ```
162
+ -->
163
+ ## Citation
164
+
165
+ View the [paper (preprint), accepted at IJCNN 2025](https://www.arxiv.org/pdf/2506.07666).
166
+ <!--
167
+ ```
168
+ @article{mousavi2024proact,
169
+ title={ProAct: Progressive Training for Hybrid Clipped Activation Function to Enhance Resilience of DNNs},
170
+ author={Mousavi, Seyedhamidreza and Ahmadilivani, Mohammad Hasan and Raik, Jaan and Jenihhin, Maksim and Daneshtalab, Masoud},
171
+ journal={arXiv preprint arXiv:2406.06313},
172
+ year={2024}
173
+ }
174
+ ```
175
+ -->
176
+ ## Acknowledgment
177
+
178
+ We acknowledge the National Academic Infrastructure for Supercomputing in Sweden (NAISS), partially funded by the Swedish Research Council through grant agreement no
179
+
180
+ ## Contributors
181
+ Some of the code in this repository is based on the following amazing works:
182
+
183
+ [Once-For-All](https://github.com/mit-han-lab/once-for-all.git)
184
+
185
+ [Hat](https://github.com/imrahulr/hat.git)
create_acc_rob_pred.py ADDED
@@ -0,0 +1,176 @@
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import torch.nn as nn
5
+ from tqdm.auto import tqdm
6
+ from torch.utils.data import DataLoader
10
+ from torch.optim import *
11
+ from torch.optim.lr_scheduler import *
13
+ from torchprofile import profile_macs
14
+ from torchvision.datasets import *
15
+ from torchvision.transforms import *
16
+ from proard.classification.data_providers.imagenet import ImagenetDataProvider
17
+ from proard.classification.run_manager import DistributedClassificationRunConfig, DistributedRunManager
18
+ from proard.model_zoo import DYN_net
19
+ from proard.nas.accuracy_predictor import AccuracyDataset,AccuracyPredictor,ResNetArchEncoder,RobustnessPredictor,MobileNetArchEncoder,AccuracyRobustnessDataset,Accuracy_Robustness_Predictor
20
+ parser = argparse.ArgumentParser()
21
+
22
+
23
+ def RMSELoss(yhat,y):
24
+ return torch.sqrt(torch.mean((yhat-y)**2))
25
+ def train(
26
+ model: nn.Module,
27
+ dataloader: DataLoader,
28
+ criterion: nn.Module,
29
+ optimizer: Optimizer,
30
+ callbacks = None,
31
+ epochs = 10,
32
+ save_path = None
33
+ ) -> None:
34
+ model.cuda()
35
+ model.train()
36
+ for epoch in range(epochs):
37
+ print(epoch)
38
+ for inputs, targets_acc, targets_rob in tqdm(dataloader, desc='train', leave=False):
39
+ inputs = inputs.float().cuda()
40
+ targets_acc = targets_acc.cuda()
41
+ targets_rob = targets_rob.cuda()
42
+
43
+ # Reset the gradients (from the last iteration)
44
+ optimizer.zero_grad()
45
+
46
+ # Forward inference
47
+ outputs = model(inputs)
48
+ loss = criterion(outputs[:,0], targets_acc) + criterion(outputs[:,1], targets_rob)
49
+
50
+ # Backward propagation
51
+ loss.backward()
52
+
53
+ # Update optimizer and LR scheduler
54
+ optimizer.step()
55
+ # scheduler.step(epoch)
56
+
57
+ if callbacks is not None:
58
+ for callback in callbacks:
59
+ callback()
60
+ torch.save(model.state_dict(), save_path)
61
+ return model
62
+
63
+ @torch.inference_mode()
64
+ def evaluate(
65
+ model: nn.Module,
66
+ dataloader: DataLoader,
67
+ ) -> float:
68
+ model.eval()
69
+
70
+ for inputs, targets_acc, targets_rob in tqdm(dataloader, desc="eval", leave=False):
71
+ # Move the data from CPU to GPU
72
+ inputs = inputs.cuda()
73
+
74
+ targets_acc = targets_acc.cuda()
75
+ targets_rob = targets_rob.cuda()
76
+
77
+
78
+ # Inference
79
+ outputs = model(inputs)
80
+
81
+ # Convert logits to class indices
82
+ print(RMSELoss(outputs[:,0],targets_acc),RMSELoss(outputs[:,1],targets_rob))
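+ # NOTE: the RMSE printed/returned here is computed per batch, not averaged over the whole validation set.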
83
+ return RMSELoss(outputs[:,0],targets_acc) + RMSELoss(outputs[:,1],targets_rob)
84
+
85
+
86
+ def get_model_flops(model, inputs):
87
+ num_macs = profile_macs(model, inputs)
88
+ return num_macs
89
+
90
+
91
+ def get_model_size(model: nn.Module, data_width=32):
92
+ """
93
+ calculate the model size in bits
94
+ :param data_width: #bits per element
95
+ """
96
+ num_elements = 0
97
+ for param in model.parameters():
98
+ num_elements += param.numel()
99
+ return num_elements * data_width
100
+
101
+
102
+
103
+
104
+
105
+ parser.add_argument(
106
+ "-p", "--path", help="The path of cifar10", type=str, default="/dataset/cifar10"
107
+ )
108
+ parser.add_argument("-g", "--gpu", help="The gpu(s) to use", type=str, default="all")
109
+ parser.add_argument(
110
+ "-b",
111
+ "--batch_size",
112
+ help="The batch on every device for validation",
113
+ type=int,
114
+ default=32,
115
+ )
116
+ parser.add_argument("-j", "--workers", help="Number of workers", type=int, default=20)
117
+ parser.add_argument(
118
+ "-n",
119
+ "--net",
120
+ metavar="DYNNET",
121
+ default="ResNet50",
122
+ choices=[
123
+ "ResNet50",
124
+ "MBV3",
125
+ "ProxylessNASNet",
126
+ ],
127
+ help="Dyanmic networks",
128
+ )
129
+ parser.add_argument(
130
+ "--dataset", type=str, default="cifar10" ,choices=["cifar10", "cifar100", "imagenet"]
131
+ )
132
+ parser.add_argument("--train_criterion", type=str, default="trades",choices=["trades","sat","mart","hat"])
133
+ parser.add_argument(
134
+ "--robust_mode", type=bool, default=True
135
+ )
136
+ args = parser.parse_args()
137
+ if args.net == "ResNet50":
138
+ arch = ResNetArchEncoder(image_size_list=[224 if args.dataset == 'imagenet' else 32],depth_list=[0,1,2],expand_list=[0.2,0.25,0.35],width_mult_list=[0.65,0.8,1.0])
139
+ else:
140
+ arch = MobileNetArchEncoder (image_size_list=[224 if args.dataset == 'imagenet' else 32],depth_list=[2,3,4],expand_list=[3,4,6],ks_list=[3,5,7])
141
+ print(arch)
142
+ acc_data = AccuracyRobustnessDataset("./acc_rob_data_{}_{}_{}".format(args.dataset,args.net,args.train_criterion))
143
+ train_loader, valid_loader, base_acc ,base_rob = acc_data.build_acc_data_loader(arch)
144
+ acc_pred_network = Accuracy_Robustness_Predictor(arch_encoder=arch,base_acc_val=None)
145
+ # optimizer_ = torch.optim.Adam(acc_pred_network.parameters(),lr=1e-3,weight_decay=1e-4)
146
+ # criterion = nn.MSELoss()
147
+ # acc_pred_network = train(acc_pred_network,train_loader,criterion,optimizer_,callbacks=None, epochs=50,save_path ="./acc_rob_data_{}_{}_{}/src/model_acc_rob.pth".format(args.dataset,args.net,args.train_criterion).format(args.dataset))
148
+ acc_pred_network.load_state_dict(torch.load("./acc_rob_data_{}_{}_{}/src/model_acc_rob.pth".format(args.dataset,args.net,args.train_criterion)))
149
+ print(evaluate(acc_pred_network,valid_loader))
150
+
151
+
152
+
153
+ # import numpy as np
154
+ # accs=[]
155
+ # robs=[]
156
+ # pred_accs=[]
157
+ # pred_robs=[]
158
+ # for x,acc,rob, in valid_loader:
159
+ # for ac in acc:
160
+ # accs.append(ac.item()*100)
161
+ # for ro in rob:
162
+ # robs.append(ro.item()*100)
163
+
164
+ # for x,acc,rob, in valid_loader:
165
+ # for arch in x:
166
+ # acc ,rob = acc_pred_network(arch.cuda())
167
+ # pred_accs.append(acc.item()*100)
168
+ # pred_robs.append(rob.item()*100)
169
+ # print(accs,robs)
170
+ # print(pred_accs,pred_robs)
171
+ # np.savetxt("./results/accs.csv", np.array(accs), delimiter=",")
172
+ # np.savetxt("./results/robs.csv", np.array(robs), delimiter=",")
173
+ # np.savetxt("./results/pred_accs.csv", np.array(pred_accs), delimiter=",")
174
+ # np.savetxt("./results/pred_robs.csv", np.array(pred_robs), delimiter=",")
175
+
176
+
create_acc_rob_pred_dataset.py ADDED
@@ -0,0 +1,75 @@
1
+ import os
2
+ import torch
3
+ import argparse
4
+
5
+ from proard.classification.data_providers.imagenet import ImagenetDataProvider
6
+ from proard.classification.run_manager import DistributedClassificationRunConfig, DistributedRunManager
7
+ from proard.model_zoo import DYN_net
8
+ from proard.nas.accuracy_predictor import AccuracyRobustnessDataset
9
+ import horovod.torch as hvd
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument(
12
+ "-p", "--path", help="The path of cifar10", type=str, default="/dataset/cifar10"
13
+ )
14
+ parser.add_argument("-g", "--gpu", help="The gpu(s) to use", type=str, default="all")
15
+ parser.add_argument(
16
+ "-b",
17
+ "--batch_size",
18
+ help="The batch on every device for validation",
19
+ type=int,
20
+ default=32,
21
+ )
22
+ parser.add_argument("-j", "--workers", help="Number of workers", type=int, default=20)
23
+ parser.add_argument(
24
+ "-n",
25
+ "--net",
26
+ metavar="DYNNET",
27
+ default="ResNet50",
28
+ choices=[
29
+ "ResNet50",
30
+ "MBV3",
31
+ "ProxylessNASNet",
32
+ "MBV2"
33
+ ],
34
+ help="Dynamic networks",
35
+ )
36
+ parser.add_argument(
37
+ "--dataset", type=str, default="cifar10" ,choices=["cifar10", "cifar100", "imagenet"]
38
+ )
39
+ parser.add_argument("--train_criterion", type=str, default="trades",choices=["trades","sat","mart","hat"])
40
+ parser.add_argument(
41
+ "--robust_mode", type=bool, default=True
42
+ )
43
+ parser.add_argument(
44
+ "--WPS", type=bool, default=True
45
+ )
46
+ parser.add_argument(
47
+ "--base", type=bool, default=False
48
+ )
49
+ # Initialize Horovod
50
+ hvd.init()
51
+ # Pin GPU to be used to process local rank (one GPU per process)
52
+ torch.cuda.set_device(hvd.local_rank())
53
+ num_gpus = hvd.size()
54
+
55
+ args = parser.parse_args()
56
+ if args.gpu == "all":
57
+ device_list = range(torch.cuda.device_count())
58
+ args.gpu = ",".join(str(_) for _ in device_list)
59
+ else:
60
+ device_list = [int(_) for _ in args.gpu.split(",")]
61
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
62
+ args.test_batch_size = args.batch_size # * max(len(device_list), 1)
63
+ ImagenetDataProvider.DEFAULT_PATH = args.path
64
+
65
+
66
+ distributed_run_config = DistributedClassificationRunConfig(**args.__dict__, num_replicas=num_gpus, rank=hvd.rank())
67
+ dyn_network = DYN_net(args.net, args.robust_mode , args.dataset, args.train_criterion, pretrained=True,run_config=distributed_run_config,WPS=args.WPS)
68
+ compression = hvd.Compression.none
69
+ distributed_run_manager = DistributedRunManager(".tmp/eval_subnet", dyn_network, distributed_run_config,compression,is_root=(hvd.rank() == 0),init=False)
70
+ distributed_run_manager.save_config()
71
+ # hvd broadcast
72
+ distributed_run_manager.broadcast()
73
+ acc_data = AccuracyRobustnessDataset("./acc_rob_data_WPS_{}_{}_{}".format(args.dataset,args.net,args.train_criterion))
74
+
75
+ acc_data.build_acc_rob_dataset(distributed_run_manager,dyn_network,image_size_list=[224 if args.dataset == "imagenet" else 32])
eval_ofa_net.py ADDED
@@ -0,0 +1,94 @@
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import os
6
+ import torch
7
+ import argparse
8
+
9
+ from proard.classification.data_providers.imagenet import ImagenetDataProvider
10
+ from proard.classification.data_providers.cifar10 import Cifar10DataProvider
11
+ from proard.classification.data_providers.cifar100 import Cifar100DataProvider
12
+ from proard.classification.run_manager import ClassificationRunConfig, RunManager
13
+ from proard.model_zoo import DYN_net
14
+
15
+
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument(
18
+ "-p", "--path", help="The path of imagenet", type=str, default="/dataset/imagenet"
19
+ )
20
+ parser.add_argument("-g", "--gpu", help="The gpu(s) to use", type=str, default="all")
21
+ parser.add_argument(
22
+ "-b",
23
+ "--batch-size",
24
+ help="The batch on every device for validation",
25
+ type=int,
26
+ default=16,
27
+ )
28
+ parser.add_argument("-j", "--workers", help="Number of workers", type=int, default=20)
29
+ parser.add_argument(
30
+ "-n",
31
+ "--net",
32
+ metavar="DYNET",
33
+ default="ResNet50",
34
+ choices=[
35
+ "ResNet50",
36
+ "MBV3",
37
+ "ProxylessNASNet",
38
+ "MBV2",
39
+ "WideResNet"
40
+ ],
41
+ help="dynamic networks",
42
+ )
43
+ parser.add_argument(
44
+ "--dataset", type=str, default="cifar10" ,choices=["cifar10", "cifar100", "imagenet"]
45
+ )
46
+ parser.add_argument(
47
+ "--attack", type=str, default="autoattack" ,choices=['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd','squar_attack','autoattack','apgd_ce']
48
+ )
49
+ parser.add_argument("--train_criterion", type=str, default="trades",choices=["trades","sat","mart","hat"])
50
+ parser.add_argument(
51
+ "--robust_mode", type=bool, default=True
52
+ )
53
+ parser.add_argument(
54
+ "--WPS", type=bool, default=False
55
+ )
56
+ parser.add_argument(
57
+ "--base", type=bool, default=False
58
+ )
59
+ args = parser.parse_args()
60
+ if args.gpu == "all":
61
+ device_list = range(torch.cuda.device_count())
62
+ args.gpu = ",".join(str(_) for _ in device_list)
63
+ else:
64
+ device_list = [int(_) for _ in args.gpu.split(",")]
65
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
66
+ args.batch_size = args.batch_size * max(len(device_list), 1)
67
+ ImagenetDataProvider.DEFAULT_PATH = args.path
68
+
69
+ run_config = ClassificationRunConfig(attack_type=args.attack,dataset= args.dataset, test_batch_size=args.batch_size, n_worker=args.workers,robust_mode=args.robust_mode)
70
+ dyn_network = DYN_net(args.net,args.robust_mode,args.dataset, args.train_criterion ,pretrained=True,run_config=run_config,WPS=args.WPS,base=args.base)
71
+ """ Randomly sample a sub-network,
72
+ you can also manually set the sub-network using:
73
+ dyn_network.set_active_subnet(ks=7, e=6, d=4)
74
+ """
75
+ if not args.base:
76
+ # dyn_network.set_active_subnet(ks=3, e=4, d=2)
77
+ dyn_network.set_active_subnet(d=2,e=0.35,w=1.0)
78
+ # dyn_network.sample_active_subnet()
79
+ # dyn_network.set_max_net()
80
+ subnet = dyn_network.get_active_subnet(preserve_weight=True)
81
+ # print(subnet)
82
+ else:
83
+ subnet = dyn_network
84
+ """ Test sampled subnet
85
+ """
86
+ run_manager = RunManager(".tmp/eval_subnet", subnet, run_config, init=False)
87
+ run_config.data_provider.assign_active_img_size(32)
88
+ run_manager.reset_running_statistics(net=subnet)
89
+
90
+ print("Test random subnet:")
91
+ # print(subnet.module_str)
92
+
93
+ loss, (top1, top5,robust1,robust5) = run_manager.validate(net=subnet,is_test=True)
94
+ print("Results: loss=%.5f,\t top1=%.1f,\t top5=%.1f,\t robust1=%.1f,\t robust5=%.1f" % (loss, top1, top5,robust1,robust5))
hugging_face.py ADDED
@@ -0,0 +1,21 @@
1
+ from huggingface_hub import interpreter_login
2
+ from huggingface_hub import upload_folder, delete_folder, upload_file
3
+ # interpreter_login()
4
+
5
+
6
+
7
+ # upload_folder(folder_path = "attacks/",path_in_repo="attacks", repo_id="smi08/ProArd")
8
+ # upload_folder(folder_path = "images/",path_in_repo="images", repo_id="smi08/ProArd")
9
+ # upload_folder(folder_path = "proard/",path_in_repo="proard", repo_id="smi08/ProArd")
10
+ # upload_folder(folder_path = "robust_loss/",path_in_repo="robust_loss", repo_id="smi08/ProArd")
11
+ # upload_folder(folder_path = "utils/",path_in_repo="utils", repo_id="smi08/ProArd")
12
+ # delete_folder(path_in_repo="smi08", repo_id="smi08/ProArd")
13
+ upload_file(path_or_fileobj="create_acc_rob_pred_dataset.py",path_in_repo="",repo_id="smi08/ProArd")
14
+ upload_file(path_or_fileobj="create_acc_rob_pred.py",path_in_repo="",repo_id="smi08/ProArd")
15
+ upload_file(path_or_fileobj="eval_ofa_net.py",path_in_repo="",repo_id="smi08/ProArd")
16
+ upload_file(path_or_fileobj="sample_eval.py",path_in_repo="",repo_id="smi08/ProArd")
17
+ upload_file(path_or_fileobj="search_best.py",path_in_repo="",repo_id="smi08/ProArd")
18
+ upload_file(path_or_fileobj="train_ofa_net_WPS.py",path_in_repo="",repo_id="smi08/ProArd")
19
+ upload_file(path_or_fileobj="train_ofa_net.py",path_in_repo="",repo_id="smi08/ProArd")
20
+ upload_file(path_or_fileobj="train_teacher_net.py",path_in_repo="",repo_id="smi08/ProArd")
21
+ upload_file(path_or_fileobj="README.md",path_in_repo="",repo_id="smi08/ProArd")
sample_eval.py ADDED
@@ -0,0 +1,89 @@
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import os
6
+ import torch
7
+ import argparse
8
+ import sys
9
+ from proard.classification.data_providers.imagenet import ImagenetDataProvider
10
+ from proard.classification.data_providers.cifar10 import Cifar10DataProvider
11
+ from proard.classification.data_providers.cifar100 import Cifar100DataProvider
12
+ from proard.classification.run_manager import ClassificationRunConfig, RunManager,DistributedRunManager
13
+ from proard.model_zoo import DYN_net
14
+ from proard.nas.accuracy_predictor import AccuracyDataset,AccuracyPredictor,ResNetArchEncoder,RobustnessPredictor,MobileNetArchEncoder,AccuracyRobustnessDataset,Accuracy_Robustness_Predictor
15
+
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument(
18
+ "-p", "--path", help="The path of imagenet", type=str, default="/dataset/imagenet"
19
+ )
20
+ parser.add_argument("-g", "--gpu", help="The gpu(s) to use", type=str, default="all")
21
+ parser.add_argument(
22
+ "-b",
23
+ "--batch-size",
24
+ help="The batch on every device for validation",
25
+ type=int,
26
+ default=128,
27
+ )
28
+ parser.add_argument("-j", "--workers", help="Number of workers", type=int, default=20)
29
+ parser.add_argument(
30
+ "-n",
31
+ "--net",
32
+ metavar="DYNNET",
33
+ default="MBV3",
34
+ choices=[
35
+ "ResNet50",
36
+ "MBV3",
37
+ "ProxylessNASNet",
38
+ "MBV2"
39
+ ],
40
+ help="dynamic networks",
41
+ )
42
+ parser.add_argument(
43
+ "--dataset", type=str, default="cifar10" ,choices=["cifar10", "cifar100", "imagenet"]
44
+ )
45
+ parser.add_argument("--train_criterion", type=str, default="trades",choices=["trades","sat","mart","hat"])
46
+ parser.add_argument(
47
+ "--robust_mode", type=bool, default=True
48
+ )
49
+ parser.add_argument(
50
+ "--WPS", type=bool, default=False
51
+ )
52
+ args = parser.parse_args()
53
+ if args.gpu == "all":
54
+ device_list = range(torch.cuda.device_count())
55
+ args.gpu = ",".join(str(_) for _ in device_list)
56
+ else:
57
+ device_list = [int(_) for _ in args.gpu.split(",")]
58
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
59
+ args.batch_size = args.batch_size * max(len(device_list), 1)
60
+ ImagenetDataProvider.DEFAULT_PATH = args.path
61
+
62
+ run_config = ClassificationRunConfig(dataset= args.dataset, test_batch_size=args.batch_size, n_worker=args.workers,robust_mode=args.robust_mode)
63
+ dyn_network = DYN_net(args.net,args.robust_mode,args.dataset, args.train_criterion ,pretrained=True,run_config=run_config,WPS=args.WPS)
64
+ """ Randomly sample a sub-network,
65
+ you can also manually set the sub-network using:
66
+ dyn_network.set_active_subnet(ks=7, e=6, d=4)
67
+ """
68
+ # dyn_network.set_active_subnet(ks=3, e=3, d=2)
69
+ # dyn_network.set_active_subnet(d=4,e=0.25,w=1)
70
+ import random
71
+ import numpy as np
72
+ random.seed(0)
73
+ np.random.seed(0)
74
+ acc1,rob1,acc2,rob2 =[],[],[],[]
75
+ if args.net == "ResNet50":
76
+ arch = ResNetArchEncoder(image_size_list=[224 if args.dataset == 'imagenet' else 32],depth_list=[0,1,2],expand_list=[0.2,0.25,0.35],width_mult_list=[0.65,0.8,1.0])
77
+ else:
78
+ arch = MobileNetArchEncoder (image_size_list=[224 if args.dataset == 'imagenet' else 32],depth_list=[2,3,4],expand_list=[3,4,6],ks_list=[3,5,7])
79
+ print(arch)
80
+ acc_data = AccuracyRobustnessDataset("./acc_rob_data_{}_{}_{}".format(args.dataset,args.net,args.train_criterion))
81
+ train_loader, valid_loader, base_acc ,base_rob = acc_data.build_acc_data_loader(arch)
82
+ for inputs, targets_acc, targets_rob in train_loader:
83
+ for i in range(len(targets_acc)):
84
+ acc1.append(targets_acc[i].item() * 100)
85
+ rob1.append(targets_rob[i].item() * 100)
86
+
87
+ np.save("./results/acc_mbv3.npy",np.array(acc1))
88
+ np.save("./results/rob_mbv3.npy",np.array(rob1))
89
+
search_best.py ADDED
@@ -0,0 +1,273 @@
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from tqdm.auto import tqdm
7
+ from torch.utils.data import DataLoader
10
+ import random
12
+ from torch.optim import *
13
+ from torch.optim.lr_scheduler import *
15
+ from torchprofile import profile_macs
16
+ from torchvision.datasets import *
17
+ from torchvision.transforms import *
18
+ from proard.model_zoo import DYN_net
19
+ from proard.nas.accuracy_predictor import AccuracyPredictor,ResNetArchEncoder,RobustnessPredictor,MobileNetArchEncoder,Accuracy_Robustness_Predictor
20
+ from proard.nas.efficiency_predictor import ResNet50FLOPsModel,Mbv3FLOPsModel,ProxylessNASFLOPsModel
21
+ from proard.nas.search_algorithm import EvolutionFinder,DynIndividual_mbv,DynIndividual_res,DynRandomSampler,DynProblem_mbv,DynProblem_res,DynSampling,individual_to_arch_res,individual_to_arch_mbv
22
+ from utils.profile import trainable_param_num
23
+ from pymoo.core.individual import Individual
24
+ from pymoo.core.mutation import Mutation
25
+ from pymoo.core.population import Population
26
+ from pymoo.core.problem import Problem
27
+ from pymoo.core.sampling import Sampling
28
+ from pymoo.core.variable import Choice
29
+ from pymoo.operators.crossover.ux import UniformCrossover
30
+ from pymoo.operators.mutation.pm import PolynomialMutation
31
+ from pymoo.operators.mutation.rm import ChoiceRandomMutation
32
+ from pymoo.operators.selection.rnd import RandomSelection
33
+ from pymoo.operators.selection.tournament import TournamentSelection
34
+ from pymoo.algorithms.moo.nsga2 import NSGA2
35
+ from pymoo.algorithms.moo.sms import SMSEMOA
36
+ from pymoo.algorithms.moo.spea2 import SPEA2
37
+ from pymoo.optimize import minimize
38
+ from pymoo.termination import get_termination
39
+ from pymoo.termination.default import DefaultMultiObjectiveTermination
40
+ from pymoo.core.callback import Callback
41
+ from pymoo.util.display.column import Column
42
+ from pymoo.util.display.output import Output
43
+ from proard.classification.run_manager import ClassificationRunConfig, RunManager
44
+ parser = argparse.ArgumentParser()
45
+ parser.add_argument(
46
+ "-p", "--path", help="The path of cifar10", type=str, default="/dataset/cifar10"
47
+ )
48
+ parser.add_argument("-g", "--gpu", help="The gpu(s) to use", type=str, default="all")
49
+ parser.add_argument(
50
+ "-b",
51
+ "--batch-size",
52
+ help="The batch on every device for validation",
53
+ type=int,
54
+ default=100,
55
+ )
56
+ parser.add_argument("-j", "--workers", help="Number of workers", type=int, default=20)
57
+ parser.add_argument(
58
+ "-n",
59
+ "--net",
60
+ metavar="DYNNET",
61
+ default="ResNet50",
62
+ choices=[
63
+ "ResNet50",
64
+ "MBV3",
65
+ "ProxylessNASNet",
66
+ ],
67
+ help="dynamic networks",
68
+ )
69
+ parser.add_argument(
70
+ "--dataset", type=str, default="cifar10" ,choices=["cifar10", "cifar100", "imagenet"]
71
+ )
72
+ parser.add_argument(
73
+ "--attack", type=str, default="linf-pgd" ,choices=['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd','squar_attack','autoattack','apgd_ce']
74
+ )
75
+ parser.add_argument("--train_criterion", type=str, default="trades",choices=["trades","sat","mart","hat"])
76
+ parser.add_argument(
77
+ "--robust_mode", type=bool, default=True
78
+ )
79
+ args = parser.parse_args()
80
+ if args.gpu == "all":
81
+ device_list = range(torch.cuda.device_count())
82
+ args.gpu = ",".join(str(_) for _ in device_list)
83
+ else:
84
+ device_list = [int(_) for _ in args.gpu.split(",")]
85
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
86
+ args.batch_size = args.batch_size * max(len(device_list), 1)
87
+ run_config = ClassificationRunConfig(attack_type=args.attack, dataset= args.dataset, test_batch_size=args.batch_size, n_worker=args.workers,robust_mode=args.robust_mode)
88
+ dyn_network = DYN_net(args.net,args.robust_mode,args.dataset,args.train_criterion, pretrained=True,run_config=run_config)
89
+ if args.net == "ResNet50":
90
+ efficiency_predictor = ResNet50FLOPsModel(dyn_network)
91
+ arch = ResNetArchEncoder(image_size_list=[32],depth_list=[0,1,2],expand_list=[0.2,0.25,0.35],width_mult_list=[0.65,0.8,1.0])
92
+ accuracy_robustness_predictor = Accuracy_Robustness_Predictor(arch)
93
+ accuracy_robustness_predictor.load_state_dict(torch.load("./acc_rob_data_{}_{}_{}/src/model_acc_rob.pth".format(args.dataset,args.net,args.train_criterion)))
94
+ elif args.net == "MBV3":
95
+ efficiency_predictor = Mbv3FLOPsModel(dyn_network)
96
+ arch = MobileNetArchEncoder(image_size_list=[32],depth_list=[2,3,4],expand_list=[3,4,6],ks_list=[3,5,7])
97
+ accuracy_robustness_predictor = Accuracy_Robustness_Predictor(arch)
98
+ accuracy_robustness_predictor.load_state_dict(torch.load("./acc_rob_data_{}_{}_{}/src/model_acc_rob.pth".format(args.dataset,args.net,args.train_criterion)))
99
+ elif args.net == "ProxylessNASNet":
100
+ efficiency_predictor = ProxylessNASFLOPsModel(dyn_network)
101
+ arch = MobileNetArchEncoder(image_size_list=[32],depth_list=[2,3,4],expand_list=[3,4,6],width_mult_list=[3,5,7])
102
+ accuracy_robustness_predictor = Accuracy_Robustness_Predictor(arch)
103
+ accuracy_robustness_predictor.load_state_dict(torch.load("./acc_rob_data_{}_{}_{}/src/model_acc_rob.pth".format(args.dataset,args.net,args.train_criterion)))
104
+ ##### Test #################################################
105
+ dyn_sampler = DynRandomSampler(arch, efficiency_predictor)
106
+ # arch1, eff1 = dyn_sampler.random_sample()
107
+ # arch2, eff2 = dyn_sampler.random_sample()
108
+ # print(accuracy_predictor.predict_acc([arch1, arch2]))
109
+ # print(arch1,eff1)
110
+ ##################################################
111
+
112
+ """ Hyperparameters
113
+ - P: size of the population in each generation (number of individuals)
114
+ - N: number of generations to run the algorithm
115
+ - mutate_prob: probability of gene mutation in the evolutionary search
116
+ """
117
+ P = 100
118
+ N = 100
119
+ mutation_prob = 0.5
120
+
121
+
122
+ # variables options
123
+ if args.net == 'ResNet50':
124
+ search_space = {
125
+ 'e': [0.2, 0.25, 0.35],
126
+ 'd': [0, 1, 2],
127
+ 'w': [0 ,1 ,2],
128
+ 'image_size': [32]
129
+ }
130
+ else:
131
+ search_space = {
132
+ 'ks': [3, 5, 7],
133
+ 'e': [3, 4, 6],
134
+ 'd': [2, 3, 4],
135
+ 'image_size': [32]
136
+ }
137
+
138
+ #----------------------------
139
+ # units
140
+ num_blocks = arch.max_n_blocks
141
+ num_stages = arch.n_stage
142
+ Flops_constraints = 1600
143
+ if args.net == "ResNet50":
144
+ problem = DynProblem_res(efficiency_predictor, accuracy_robustness_predictor, num_blocks, num_stages, search_space,Flops_constraints)
145
+ else:
146
+ problem = DynProblem_mbv(efficiency_predictor, accuracy_robustness_predictor, num_blocks, num_stages, search_space,Flops_constraints)
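+ # Objectives (inferred from the plotting/CSV code below): F = [100 - predicted accuracy, 100 - predicted robustness],
+ # and the FLOPs budget `Flops_constraints` is handed to the problem as the search constraint (G).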
147
+
148
+
149
+
150
+
151
+
152
+ mutation_rc = ChoiceRandomMutation(prob=1.0, prob_var=0.1)
153
+ crossover_ux = UniformCrossover(prob=1.0)
154
+ # selection_tournament = TournamentSelection(
155
+ # func_comp=accuracy_predictor.predict_acc,
156
+ # pressure=2
157
+ # )
158
+ termination_default = DefaultMultiObjectiveTermination(
159
+ xtol=1e-8, cvtol=1e-6, ftol=0.0025, period=30, n_max_gen=1000, n_max_evals=100000
160
+ )
161
+ termination_gen = get_termination("n_gen", N)
162
+ np.random.seed(42)
163
+ random.seed(42)
164
+ if args.net=="ResNet50":
165
+ init_pop = Population(individuals=[DynIndividual_res(dyn_sampler.random_sample(), accuracy_robustness_predictor) for _ in range(P)])
166
+ else:
167
+ init_pop = Population(individuals=[DynIndividual_mbv(dyn_sampler.random_sample(), accuracy_robustness_predictor) for _ in range(P)])
168
+
169
+ algorithm = NSGA2(
170
+ pop_size=P,
171
+ sampling=DynSampling(),
172
+ # selection=selection_tournament,
173
+ crossover=crossover_ux,
174
+ mutation=mutation_rc,
175
+ # mutation=mutation_pm,
176
+ # survival=RankAndCrowdingSurvival(),
177
+ # output=MultiObjectiveOutput(),
178
+ # **kwargs
179
+ )
180
+ res_nsga2 = minimize(
181
+ problem,
182
+ algorithm,
183
+ termination=termination_gen,
184
+ seed=1,
185
+ #verbose=True,
186
+ verbose=False,
187
+ save_history=True,
188
+ )
189
+ # print(100-res_nsga2.history[99].pop.get('F')[:,0],100-res_nsga2.history[99].pop.get('F')[:,1])
190
+ # a = individual_to_arch_res(res_nsga2.pop.get('X'),num_blocks)[0]
191
+ # # print(a)
192
+ # # a['d'][3] = int(a['d'][3])
193
+ # a['d'][4] = int(a['d'][4])
194
+ # dyn_network.set_active_subnet(**a)
195
+ # subnet = dyn_network.get_active_subnet(preserve_weight=True)
196
+ # run_manager = RunManager(".tmp/eval_subnet", subnet, run_config, init=False)
197
+ # run_config.data_provider.assign_active_img_size(32)
198
+ # run_manager.reset_running_statistics(net=subnet)
199
+
200
+ # print("Test random subnet:")
201
+ # # print(subnet.module_str)
202
+
203
+ # loss, (top1, top5,robust1,robust5) = run_manager.validate(net=subnet,is_test=True)
204
+ # print("Results: loss=%.5f,\t top1=%.1f,\t top5=%.1f,\t robust1=%.1f,\t robust5=%.1f" % (loss, top1, top5,robust1,robust5))
205
+
206
+
207
+ np.savetxt("./results/acc_gen0.csv", 100-res_nsga2.history[0].pop.get('F')[:,0], delimiter=",")
208
+
209
+ np.savetxt("./results/acc_gen99.csv", 100-res_nsga2.history[99].pop.get('F')[:,0], delimiter=",")
210
+ np.savetxt("./results/rob_gen0.csv", 100-res_nsga2.history[0].pop.get('F')[:,1], delimiter=",")
211
+
212
+ np.savetxt("./results/rob_gen99.csv", 100-res_nsga2.history[99].pop.get('F')[:,1], delimiter=",")
213
+ np.savetxt("./results/flops_gen99.csv", res_nsga2.history[99].pop.get('G'), delimiter=",")
214
+
215
+ # np.savetxt("./results/robs.csv", np.array(robs), delimiter=",")
216
+
217
+ from matplotlib import pyplot as plt
218
+ from matplotlib.ticker import FormatStrFormatter
219
+ from matplotlib.ticker import AutoMinorLocator, MultipleLocator
220
+ # NSGA-II population progression
221
+ x_min, x_max, y_min, y_max = 80, 93, 47, 56
222
+ ax_limits = [x_min, x_max, y_min, y_max]
223
+ #-------------------------------------------------
224
+ # plot
225
+ fig, ax = plt.subplots(dpi=600)
226
+ gen0 = 0
227
+ gen1 = 99
228
+ print(100-res_nsga2.history[gen1].pop.get('F')[:,0], 100 - res_nsga2.history[gen1].pop.get('F')[:,1])
229
+ # gen2 = 99
230
+ # print(res_nsga2.history[gen0].pop.get('F')[:,0],res_nsga2.history[gen0].pop.get('F')[:,1] )
231
+ ax.plot(100-res_nsga2.history[gen0].pop.get('F')[:,0], 100 - res_nsga2.history[gen0].pop.get('F')[:,1] , 'o', label=f'Population at generation #{gen0+1}', color='red', alpha=0.5)
232
+ ax.plot(100-res_nsga2.history[gen1].pop.get('F')[:,0], 100 - res_nsga2.history[gen1].pop.get('F')[:,1] , 'o', label=f'Population at generation #{gen1+1}', color='green', alpha=0.5)
233
+ # ax.plot(res_nsga2.history[gen2].pop.get('F')[:,0], 100 - res_nsga2.history[gen2].pop.get('F')[:,1], 'o', label=f'Population at generation #{gen2+1}', color='orange', alpha=0.5)
234
+ # ax.plot(res_nsga2.history[gen3].pop.get('F')[:,0], 100 - res_nsga2.history[gen3].pop.get('F')[:,1], 'o', label=f'Population at generation #{gen3+1}', color='blue', alpha=0.5)
235
+ #-------------------------------------------------
236
+ # text
237
+ ax.grid(True, linestyle=':')
238
+ ax.set_xlabel('Accuracy (%)')
239
+ ax.set_ylabel('Robustness (%)')
240
+ ax.set_title('NSGA-II solution progression under a FLOPs constraint')
241
+ ax.legend()
242
+ #-------------------------------------------------
243
+ # x-axis
244
+ ax.xaxis.set_major_locator(MultipleLocator(1))
245
+ ax.xaxis.set_minor_locator(MultipleLocator(1))
246
+ # y-axis
247
+ ax.yaxis.set_major_locator(MultipleLocator(1))
248
+ ax.yaxis.set_minor_locator(MultipleLocator(1))
249
+ # ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
250
+ ax.set(xlim=(ax_limits[0], ax_limits[1]), ylim=(ax_limits[2], ax_limits[3]))
251
+ #-------------------------------------------------
252
+ plt.savefig('nsga2_pop_progression_debug.png')
253
+ fig.set_dpi(100)
254
+ # plt.close(fig)
255
+
256
+
257
+
258
+ # plt.show()
259
+
260
+
261
+ # finder = EvolutionFinder(efficiency_predictor,accuracy_predictor,Robustness_predictor)
262
+ # valid_constraint_range = 800
263
+ # best_valids, best_info = finder.run_evolution_search(constraint=valid_constraint_range,verbose=True)
264
+ # print(efficiency_predictor.get_efficiency(best_info[2]))
265
+ # dyn_network.set_active_subnet(best_info[2]['d'],best_info[2]['e'],best_info[2]['w'])
266
+ # subnet = dyn_network.get_active_subnet(preserve_weight=True)
267
+ # run_config = CifarRunConfig_robust(test_batch_size=args.batch_size, n_worker=args.workers)
268
+ # run_manager = RunManager_robust(".tmp/eval_subnet", subnet, run_config, init=False)
269
+ # run_config.data_provider.assign_active_img_size(32)
270
+ # run_manager.reset_running_statistics(net=subnet)
271
+ # loss, (top1, top5,robust1,robust5) = run_manager.validate(net=subnet)
272
+ # print("Results: loss=%.5f,\t top1=%.1f,\t top5=%.1f,\t robust1=%.1f,\t robust5=%.1f" % (loss, top1, top5,robust1,robust5))
273
+ # print("number of parameter={}M".format(trainable_param_num(subnet)))
train_ofa_net.py ADDED
@@ -0,0 +1,558 @@
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import argparse
6
+ import numpy as np
7
+ import os
8
+ import random
9
+
10
+ # used for distributed training
11
+ import horovod.torch as hvd
12
+ import torch
13
+
14
+
15
+ from proard.classification.elastic_nn.modules.dynamic_op import (
16
+ DynamicSeparableConv2d,
17
+ )
18
+ from proard.classification.elastic_nn.networks import DYNMobileNetV3,DYNProxylessNASNets,DYNResNets,DYNProxylessNASNets_Cifar,DYNMobileNetV3_Cifar,DYNResNets_Cifar
19
+ from proard.classification.run_manager import DistributedClassificationRunConfig
20
+ from proard.classification.run_manager.distributed_run_manager import (
21
+ DistributedRunManager
22
+ )
23
+ from proard.utils import download_url, MyRandomResizedCrop
24
+ from proard.classification.elastic_nn.training.progressive_shrinking import load_models
25
+
26
+ parser = argparse.ArgumentParser()
27
+ parser.add_argument(
28
+ "--task",
29
+ type=str,
30
+ default="expand",
31
+ choices=[
32
+ "kernel", # for architecture except ResNet
33
+ "depth",
34
+ "expand",
35
+ "width", # only for ResNet
36
+ ],
37
+ )
38
+ parser.add_argument("--phase", type=int, default=2, choices=[1, 2])
39
+ parser.add_argument("--resume", action="store_true")
40
+ parser.add_argument("--model_name", type=str, default="MBV2", choices=["ResNet50", "MBV3", "ProxylessNASNet","MBV2"])
41
+ parser.add_argument("--dataset", type=str, default="cifar100", choices=["cifar10", "cifar100", "imagenet"])
42
+ parser.add_argument("--robust_mode", type=bool, default=True)
43
+ parser.add_argument("--epsilon", type=float, default=0.031)
44
+ parser.add_argument("--num_steps", type=int, default=10)
45
+ parser.add_argument("--step_size", type=float, default=0.0078)
46
+ parser.add_argument("--clip_min", type=int, default=0)
47
+ parser.add_argument("--clip_max", type=int, default=1)
48
+ parser.add_argument("--const_init", type=bool, default=False)
49
+ parser.add_argument("--beta", type=float, default=6.0)
50
+ parser.add_argument("--distance", type=str, default="l_inf",choices=["l_inf","l2"])
51
+ parser.add_argument("--train_criterion", type=str, default="trades",choices=["trades","sat","mart","hat"])
52
+ parser.add_argument("--test_criterion", type=str, default="ce",choices=["ce"])
53
+ parser.add_argument("--kd_criterion", type=str, default="rslad",choices=["ard","rslad","adaad"])
54
+ parser.add_argument("--attack_type", type=str, default="linf-pgd",choices=['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd','squar_attack','autoattack','apgd_ce'])
55
+
56
+ args = parser.parse_args()
57
+ if args.model_name == "ResNet50":
58
+ args.ks_list = "3"
59
+ if args.task == "width":
60
+ if args.robust_mode:
61
+ args.path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2width"
62
+ else:
63
+ args.path = "exp/"+ args.dataset + '/' +args.model_name +'/' + args.train_criterion +"/normal2width"
64
+ args.dynamic_batch_size = 1
65
+ args.n_epochs = 120
66
+ args.base_lr = 3e-2
67
+ args.warmup_epochs = 5
68
+ args.warmup_lr = -1
69
+ args.width_mult_list = "0.65,0.8,1.0"
70
+ args.expand_list = "0.35"
71
+ args.depth_list = "2"
72
+ elif args.task == "depth":
73
+ if args.robust_mode:
74
+ args.path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase%d" % args.phase
75
+ else:
76
+ args.path = "exp/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase%d" % args.phase
77
+ args.dynamic_batch_size = 2
78
+ if args.phase == 1:
79
+ args.n_epochs = 25
80
+ args.base_lr = 2.5e-3
81
+ args.warmup_epochs = 0
82
+ args.warmup_lr = -1
83
+ args.width_mult_list = "0.65,0.8,1.0"
84
+ args.expand_list ="0.35"
85
+ args.depth_list = "1,2"
86
+ else:
87
+ args.n_epochs = 120
88
+ args.base_lr = 7.5e-3
89
+ args.warmup_epochs = 5
90
+ args.warmup_lr = -1
91
+ args.width_mult_list = "0.65,0.8,1.0"
92
+ args.expand_list = "0.35"
93
+ args.depth_list = "0,1,2"
94
+ elif args.task == "expand":
95
+ if args.robust_mode :
96
+ args.path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase%d" % args.phase
97
+ else:
98
+ args.path = "exp/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase%d" % args.phase
99
+ args.dynamic_batch_size = 4
100
+ if args.phase == 1:
101
+ args.n_epochs = 25
102
+ args.base_lr = 2.5e-3
103
+ args.warmup_epochs = 0
104
+ args.warmup_lr = -1
105
+ args.width_mult_list = "0.65,0.8,1.0"
106
+ args.expand_list = "0.25,0.35"
107
+ args.depth_list = "0,1,2"
108
+ else:
109
+ args.n_epochs = 120
110
+ args.base_lr = 7.5e-3
111
+ args.warmup_epochs = 5
112
+ args.warmup_lr = -1
113
+ args.width_mult_list = "0.65,0.8,1.0"
114
+ args.expand_list = "0.2,0.25,0.35"
115
+ args.depth_list = "0,1,2"
116
+ else:
117
+ raise NotImplementedError
118
+ else:
119
+ args.width_mult_list = "1.0"
120
+ if args.task == "kernel":
121
+ if args.robust_mode:
122
+ args.path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel"
123
+ else:
124
+ args.path = "exp/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel"
125
+ args.dynamic_batch_size = 1
126
+ args.n_epochs = 120
127
+ args.base_lr = 3e-2
128
+ args.warmup_epochs = 5
129
+ args.warmup_lr = -1
130
+ args.ks_list = "3,5,7"
131
+ args.expand_list = "6"
132
+ args.depth_list = "4"
133
+ elif args.task == "depth":
134
+ if args.robust_mode :
135
+ args.path = "exp/robust/"+args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase%d" % args.phase
136
+ else:
137
+ args.path = "exp/"+args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase%d" % args.phase
138
+ args.dynamic_batch_size = 2
139
+ if args.phase == 1:
140
+ args.n_epochs = 25
141
+ args.base_lr = 2.5e-3
142
+ args.warmup_epochs = 0
143
+ args.warmup_lr = -1
144
+ args.ks_list = "3,5,7"
145
+ args.expand_list = "6"
146
+ args.depth_list = "3,4"
147
+ else:
148
+ args.n_epochs = 120
149
+ args.base_lr = 7.5e-3
150
+ args.warmup_epochs = 5
151
+ args.warmup_lr = -1
152
+ args.ks_list = "3,5,7"
153
+ args.expand_list = "6"
154
+ args.depth_list = "2,3,4"
155
+ elif args.task == "expand":
156
+ if args.robust_mode:
157
+ args.path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel_depth2kernel_depth_width/phase%d" % args.phase
158
+ else:
159
+ args.path = "exp/"+ args.dataset + '/' + args.model_name + '/' + args.train_criterion + "/kernel_depth2kernel_depth_width/phase%d" % args.phase
160
+ args.dynamic_batch_size = 4
161
+ if args.phase == 1:
162
+ args.n_epochs = 25
163
+ args.base_lr = 2.5e-3
164
+ args.warmup_epochs = 0
165
+ args.warmup_lr = -1
166
+ args.ks_list = "3,5,7"
167
+ args.expand_list = "4,6"
168
+ args.depth_list = "2,3,4"
169
+ else:
170
+ args.n_epochs = 120
171
+ args.base_lr = 7.5e-3
172
+ args.warmup_epochs = 5
173
+ args.warmup_lr = -1
174
+ args.ks_list = "3,5,7"
175
+ args.expand_list = "3,4,6"
176
+ args.depth_list = "2,3,4"
177
+ else:
178
+ raise NotImplementedError
179
+ args.manual_seed = 0
180
+
181
+ args.lr_schedule_type = "cosine"
182
+
183
+ args.base_batch_size = 64
184
+ args.valid_size = 64
185
+
186
+ args.opt_type = "sgd"
187
+ args.momentum = 0.9
188
+ args.no_nesterov = False
189
+ args.weight_decay = 3e-5
190
+ args.label_smoothing = 0.1
191
+ args.no_decay_keys = "bn#bias"
192
+ args.fp16_allreduce = False
193
+
194
+ args.model_init = "he_fout"
195
+ args.validation_frequency = 1
196
+ args.print_frequency = 10
197
+
198
+ args.n_worker = 8
199
+ args.resize_scale = 0.08
200
+ args.distort_color = "tf"
201
+ if args.dataset == "imagenet":
202
+ args.image_size = "128,160,192,224"
203
+ else:
204
+ args.image_size = "32"
205
+ args.continuous_size = True
206
+ args.not_sync_distributed_image_size = False
207
+
208
+ args.bn_momentum = 0.1
209
+ args.bn_eps = 1e-5
210
+ args.dropout = 0.1
211
+ args.base_stage_width = "google"
212
+
213
+
214
+ args.dy_conv_scaling_mode = 1
215
+ args.independent_distributed_sampling = False
216
+
217
+ args.kd_ratio = 1.0
218
+ args.kd_type = "ce"
219
+
220
+
221
+ if __name__ == "__main__":
222
+ os.makedirs(args.path, exist_ok=True)
223
+
224
+ # Initialize Horovod
225
+ hvd.init()
226
+ # Pin GPU to be used to process local rank (one GPU per process)
227
+ torch.cuda.set_device(hvd.local_rank())
228
+ if args.robust_mode:
229
+ args.teacher_path = 'exp/robust/teacher/' + args.dataset + '/' + args.model_name + '/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
230
+ else:
231
+ args.teacher_path = 'exp/teacher/' + args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
232
+ num_gpus = hvd.size()
233
+
234
+ torch.manual_seed(args.manual_seed)
235
+ torch.cuda.manual_seed_all(args.manual_seed)
236
+ np.random.seed(args.manual_seed)
237
+ random.seed(args.manual_seed)
238
+
239
+ # image size
240
+ args.image_size = [int(img_size) for img_size in args.image_size.split(",")]
241
+ if len(args.image_size) == 1:
242
+ args.image_size = args.image_size[0]
243
+ MyRandomResizedCrop.CONTINUOUS = args.continuous_size
244
+ MyRandomResizedCrop.SYNC_DISTRIBUTED = not args.not_sync_distributed_image_size
245
+
246
+ # build run config from args
247
+ args.lr_schedule_param = None
248
+ args.opt_param = {
249
+ "momentum": args.momentum,
250
+ "nesterov": not args.no_nesterov,
251
+ }
252
+ args.init_lr = args.base_lr * num_gpus # linearly rescale the learning rate
253
+ if args.warmup_lr < 0:
254
+ args.warmup_lr = args.base_lr
255
+ args.train_batch_size = args.base_batch_size
256
+ args.test_batch_size = args.base_batch_size * 4
257
+ run_config = DistributedClassificationRunConfig(
258
+ **args.__dict__, num_replicas=num_gpus, rank=hvd.rank()
259
+ )
260
+
261
+ # print run config information
262
+ if hvd.rank() == 0:
263
+ print("Run config:")
264
+ for k, v in run_config.config.items():
265
+ print("\t%s: %s" % (k, v))
266
+
267
+ if args.dy_conv_scaling_mode == -1:
268
+ args.dy_conv_scaling_mode = None
269
+ DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = args.dy_conv_scaling_mode
270
+
271
+ # build net from args
272
+ args.width_mult_list = [
273
+ float(width_mult) for width_mult in args.width_mult_list.split(",")
274
+ ]
275
+ args.ks_list = [int(ks) for ks in args.ks_list.split(",")]
276
+ if args.model_name == "ResNet50":
277
+ args.expand_list = [float(e) for e in args.expand_list.split(",")]
278
+ else:
279
+ args.expand_list = [int(e) for e in args.expand_list.split(",")]
280
+ args.depth_list = [int(d) for d in args.depth_list.split(",")]
281
+
282
+ args.width_mult_list = (
283
+ args.width_mult_list[0]
284
+ if len(args.width_mult_list) == 1
285
+ else args.width_mult_list
286
+ )
287
+
288
+ if args.model_name == "ResNet50":
289
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
290
+ net = DYNResNets_Cifar( n_classes=run_config.data_provider.n_classes,
291
+ bn_param=(args.bn_momentum, args.bn_eps),
292
+ dropout_rate=args.dropout,
293
+ depth_list=args.depth_list,
294
+ expand_ratio_list=args.expand_list,
295
+ width_mult_list=args.width_mult_list,)
296
+ else:
297
+ net = DYNResNets( n_classes=run_config.data_provider.n_classes,
298
+ bn_param=(args.bn_momentum, args.bn_eps),
299
+ dropout_rate=args.dropout,
300
+ depth_list=args.depth_list,
301
+ expand_ratio_list=args.expand_list,
302
+ width_mult_list=args.width_mult_list,)
303
+ elif args.model_name == "MBV3":
304
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
305
+ net = DYNMobileNetV3_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
306
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
307
+ else:
308
+ net = DYNMobileNetV3(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
309
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
310
+ elif args.model_name == "ProxylessNASNet":
311
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
312
+ net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
313
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
314
+ else:
315
+ net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
316
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
317
+ elif args.model_name == "MBV2":
318
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
319
+ net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
320
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list,base_stage_width=args.base_stage_width)
321
+ else:
322
+ net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
323
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list,base_stage_width=args.base_stage_width)
324
+ else:
325
+ raise NotImplementedError
326
+ # teacher model
327
+ if args.kd_ratio > 0:
328
+
329
+ if args.model_name =="ResNet50":
330
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
331
+ args.teacher_model = DYNResNets_Cifar(
332
+ n_classes=run_config.data_provider.n_classes,
333
+ bn_param=(args.bn_momentum, args.bn_eps),
334
+ dropout_rate=args.dropout,
335
+ depth_list=[2],
336
+ expand_ratio_list=[0.35],
337
+ width_mult_list=[1.0],
338
+ )
339
+ else:
340
+ args.teacher_model = DYNResNets(
341
+ n_classes=run_config.data_provider.n_classes,
342
+ bn_param=(args.bn_momentum, args.bn_eps),
343
+ dropout_rate=args.dropout,
344
+ depth_list=[2],
345
+ expand_ratio_list=[0.35],
346
+ width_mult_list=[1.0],
347
+ )
348
+ elif args.model_name =="MBV3":
349
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
350
+ args.teacher_model = DYNMobileNetV3_Cifar(
351
+ n_classes=run_config.data_provider.n_classes,
352
+ bn_param=(args.bn_momentum, args.bn_eps),
353
+ dropout_rate=0,
354
+ width_mult=1.0,
355
+ ks_list=[7],
356
+ expand_ratio_list=[6],
357
+ depth_list=[4]
358
+ )
359
+ else:
360
+ args.teacher_model = DYNMobileNetV3(
361
+ n_classes=run_config.data_provider.n_classes,
362
+ bn_param=(args.bn_momentum, args.bn_eps),
363
+ dropout_rate=0,
364
+ width_mult=1.0,
365
+ ks_list=[7],
366
+ expand_ratio_list=[6],
367
+ depth_list=[4]
368
+ )
369
+ elif args.model_name == "ProxylessNASNet":
370
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
371
+ args.teacher_model = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,
372
+ bn_param=(args.bn_momentum, args.bn_eps),
373
+ dropout_rate=0,
374
+ width_mult=1.0,
375
+ ks_list=[7],
376
+ expand_ratio_list=[6],
377
+ depth_list=[4])
378
+ else:
379
+ args.teacher_model = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,
380
+ bn_param=(args.bn_momentum, args.bn_eps),
381
+ dropout_rate=0,
382
+ width_mult=1.0,
383
+ ks_list=[7],
384
+ expand_ratio_list=[6],
385
+ depth_list=[4])
386
+ elif args.model_name == "MBV2":
387
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
388
+ args.teacher_model = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,
389
+ bn_param=(args.bn_momentum, args.bn_eps),
390
+ dropout_rate=0,
391
+ width_mult=1.0,
392
+ ks_list=[7],
393
+ expand_ratio_list=[6],
394
+ depth_list=[4],base_stage_width=args.base_stage_width)
395
+ else:
396
+ args.teacher_model = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,
397
+ bn_param=(args.bn_momentum, args.bn_eps),
398
+ dropout_rate=0,
399
+ width_mult=1.0,
400
+ ks_list=[7],
401
+ expand_ratio_list=[6],
402
+ depth_list=[4],base_stage_width=args.base_stage_width)
403
+ args.teacher_model.cuda()
404
+
405
+ """ Distributed RunManager """
406
+ # Horovod: (optional) compression algorithm.
407
+ compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
408
+ distributed_run_manager = DistributedRunManager(
409
+ args.path,
410
+ net,
411
+ run_config,
412
+ compression,
413
+ backward_steps=args.dynamic_batch_size,
414
+ is_root=(hvd.rank() == 0),
415
+ )
416
+ distributed_run_manager.save_config()
417
+ # hvd broadcast
418
+ distributed_run_manager.broadcast()
419
+
420
+ # load teacher net weights
421
+ if args.kd_ratio > 0:
422
+ load_models(
423
+ distributed_run_manager, args.teacher_model, model_path=args.teacher_path
424
+ )
425
+
426
+ # training
427
+ from proard.classification.elastic_nn.training.progressive_shrinking import (
428
+ validate,
429
+ train,
430
+ )
431
+ if args.model_name =="ResNet50":
432
+ validate_func_dict = {
433
+ "image_size_list": {224 if args.dataset == "imagenet" else 32}
434
+ if isinstance(args.image_size, int)
435
+ else sorted({160, 224}),
436
+ "width_mult_list": sorted({min(args.width_mult_list), max(args.width_mult_list)}),
437
+ "expand_ratio_list": sorted({min(args.expand_list), max(args.expand_list)}),
438
+ "depth_list": sorted({min(net.depth_list), max(net.depth_list)}),
439
+ }
440
+ else:
441
+ validate_func_dict = {
442
+ "image_size_list": {224 if args.dataset == "imagenet" else 32}
443
+ if isinstance(args.image_size, int)
444
+ else sorted({160, 224}),
445
+ "width_mult_list": [1.0],
446
+ "ks_list": sorted({min(args.ks_list), max(args.ks_list)}),
447
+ "expand_ratio_list": sorted({min(args.expand_list), max(args.expand_list)}),
448
+ "depth_list": sorted({min(net.depth_list), max(net.depth_list)}),
449
+ }
450
+
451
+ if args.task == "width":
452
+ from proard.classification.elastic_nn.training.progressive_shrinking import (
453
+ train_elastic_width_mult,
454
+ )
455
+ if distributed_run_manager.start_epoch == 0:
456
+ if args.robust_mode:
457
+ args.dyn_checkpoint_path ='exp/robust/teacher/' +args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
458
+ else:
459
+ args.dyn_checkpoint_path ='exp/teacher/' +args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
460
+ load_models(
461
+ distributed_run_manager,
462
+ distributed_run_manager.net,
463
+ args.dyn_checkpoint_path,
464
+ )
465
+ distributed_run_manager.write_log(
466
+ "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
467
+ % validate(distributed_run_manager, is_test=True, **validate_func_dict),
468
+ "valid",
469
+ )
470
+ else:
471
+ assert args.resume
472
+ train_elastic_width_mult(train, distributed_run_manager, args, validate_func_dict)
473
+
474
+
475
+
476
+ elif args.task == "kernel":
477
+ validate_func_dict["ks_list"] = sorted(args.ks_list)
478
+ if distributed_run_manager.start_epoch == 0:
479
+ if args.robust_mode:
480
+ args.dyn_checkpoint_path ='exp/robust/teacher/' + args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
481
+ else:
482
+ args.dyn_checkpoint_path ='exp/teacher/' + args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
483
+ load_models(
484
+ distributed_run_manager,
485
+ distributed_run_manager.net,
486
+ args.dyn_checkpoint_path,
487
+ )
488
+ distributed_run_manager.write_log(
489
+ "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
490
+ % validate(distributed_run_manager, is_test=True, **validate_func_dict),
491
+ "valid",
492
+ )
493
+ else:
494
+ assert args.resume
495
+ train(
496
+ distributed_run_manager,
497
+ args,
498
+ lambda _run_manager, epoch, is_test: validate(
499
+ _run_manager, epoch, is_test, **validate_func_dict
500
+ ),
501
+ )
502
+ elif args.task == "depth":
503
+ from proard.classification.elastic_nn.training.progressive_shrinking import (
504
+ train_elastic_depth,
505
+ )
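+ # Each elastic stage warm-starts from the best checkpoint of the previous stage (teacher -> width/kernel -> depth -> expand), as encoded in the checkpoint paths below.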
506
+ if args.robust_mode:
507
+ if args.model_name =="ResNet50":
508
+ if args.phase == 1:
509
+ args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/normal2width" +"/checkpoint/model_best.pth.tar"
510
+ else:
511
+ args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase1" + "/checkpoint/model_best.pth.tar"
512
+ else:
513
+ if args.phase == 1:
514
+ args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel" +"/checkpoint/model_best.pth.tar"
515
+ else:
516
+ args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase1" + "/checkpoint/model_best.pth.tar"
517
+ else :
518
+ if args.model_name =="ResNet50":
519
+ if args.phase == 1:
520
+ args.dyn_checkpoint_path = "exp/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/normal2width" +"/checkpoint/model_best.pth.tar"
521
+ else:
522
+ args.dyn_checkpoint_path = "exp/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase1" + "/checkpoint/model_best.pth.tar"
523
+ else:
524
+ if args.phase == 1:
525
+ args.dyn_checkpoint_path = "exp/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel" +"/checkpoint/model_best.pth.tar"
526
+ else:
527
+ args.dyn_checkpoint_path = "exp/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase1" + "/checkpoint/model_best.pth.tar"
528
+ train_elastic_depth(train, distributed_run_manager, args, validate_func_dict)
529
+ elif args.task == "expand":
530
+ from proard.classification.elastic_nn.training.progressive_shrinking import (
531
+ train_elastic_expand,
532
+ )
533
+ if args.robust_mode :
534
+ if args.model_name =="ResNet50":
535
+ if args.phase == 1:
536
+ args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width2width_depth/phase2" + "/checkpoint/model_best.pth.tar"
537
+ else:
538
+ args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
539
+ else:
540
+ if args.phase == 1:
541
+ args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase2" + "/checkpoint/model_best.pth.tar"
542
+ else:
543
+ args.dyn_checkpoint_path = "exp/robust/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel_depth2kernel_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
544
+ else:
545
+ if args.model_name =="ResNet50":
546
+ if args.phase == 1:
547
+ args.dyn_checkpoint_path = "exp/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width2width_depth/phase2" + "/checkpoint/model_best.pth.tar"
548
+ else:
549
+ args.dyn_checkpoint_path = "exp/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
550
+ else:
551
+ if args.phase == 1:
552
+ args.dyn_checkpoint_path = "exp/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase2" + "/checkpoint/model_best.pth.tar"
553
+ else:
554
+ args.dyn_checkpoint_path = "exp/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel_depth2kernel_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
555
+
556
+ train_elastic_expand(train, distributed_run_manager, args, validate_func_dict)
557
+ else:
558
+ raise NotImplementedError
train_ofa_net_WPS.py ADDED
@@ -0,0 +1,572 @@
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import argparse
6
+ import numpy as np
7
+ import os
8
+ import random
9
+
10
+ # using for distributed training
11
+ import horovod.torch as hvd
12
+ import torch
13
+
14
+
15
+ from proard.classification.elastic_nn.modules.dynamic_op import (
16
+ DynamicSeparableConv2d,
17
+ )
18
+ from proard.classification.elastic_nn.networks import DYNMobileNetV3,DYNProxylessNASNets,DYNResNets,DYNProxylessNASNets_Cifar,DYNMobileNetV3_Cifar,DYNResNets_Cifar
19
+ from proard.classification.run_manager import DistributedClassificationRunConfig
20
+ from proard.classification.run_manager.distributed_run_manager import (
21
+ DistributedRunManager
22
+ )
23
+ from proard.utils import download_url, MyRandomResizedCrop
24
+ from proard.classification.elastic_nn.training.progressive_shrinking import (
25
+ load_models,
26
+ )
27
+
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument(
30
+ "--task",
31
+ type=str,
32
+ default="expand",
33
+ choices=[
34
+ "kernel", # for architecture except ResNet
35
+ "depth",
36
+ "expand",
37
+ "width", # only for ResNet
38
+ ],
39
+ )
40
+ parser.add_argument("--phase", type=int, default=2, choices=[1, 2])
41
+ parser.add_argument("--resume", action="store_true")
42
+ parser.add_argument("--model_name", type=str, default="MBV2", choices=["ResNet50", "MBV3", "ProxylessNASNet"])
43
+ parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "cifar100", "imagenet"])
44
+ parser.add_argument("--robust_mode", type=bool, default=True)
45
+ parser.add_argument("--epsilon", type=float, default=0.031)
46
+ parser.add_argument("--num_steps", type=int, default=10)
47
+ parser.add_argument("--step_size", type=float, default=0.0078)
48
+ parser.add_argument("--clip_min", type=int, default=0)
49
+ parser.add_argument("--clip_max", type=int, default=1)
50
+ parser.add_argument("--const_init", type=bool, default=False)
51
+ parser.add_argument("--beta", type=float, default=6.0)
52
+ parser.add_argument("--distance", type=str, default="l_inf",choices=["l_inf","l2"])
53
+ parser.add_argument("--train_criterion", type=str, default="trades",choices=["trades","sat","mart","hat"])
54
+ parser.add_argument("--test_criterion", type=str, default="ce",choices=["ce"])
55
+ parser.add_argument("--kd_criterion", type=str, default="rslad",choices=["ard","rslad","adaad"])
56
+ parser.add_argument("--attack_type", type=str, default="linf-pgd",choices=['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd','squar_attack','autoattack','apgd_ce'])
57
+ args = parser.parse_args()
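+ # The remaining hyperparameters are hard-coded per task and phase: ResNet50 varies width/depth/expand, the other backbones vary kernel/depth/expand.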
58
+ if args.model_name == "ResNet50":
59
+ args.ks_list = "3"
60
+ if args.task == "width":
61
+ if args.robust_mode:
62
+ args.path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2width"
63
+ else:
64
+ args.path = "exp/WPS"+ args.dataset + '/' +args.model_name +'/' + args.train_criterion +"/normal2width"
65
+ args.dynamic_batch_size = 1
66
+ args.n_epochs = 120
67
+ args.base_lr = 3e-2
68
+ args.warmup_epochs = 5
69
+ args.warmup_lr = -1
70
+ args.width_mult_list = "0.65,0.8,1.0"
71
+ args.expand_list = "0.35"
72
+ args.depth_list = "2"
73
+ elif args.task == "depth":
74
+ if args.robust_mode:
75
+ args.path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase%d" % args.phase
76
+ else:
77
+ args.path = "exp/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase%d" % args.phase
78
+ args.dynamic_batch_size = 2
79
+ if args.phase == 1:
80
+ args.n_epochs = 25
81
+ args.base_lr = 2.5e-3
82
+ args.warmup_epochs = 0
83
+ args.warmup_lr = -1
84
+ args.width_mult_list = "0.65,0.8,1.0"
85
+ args.expand_list ="0.35"
86
+ args.depth_list = "1,2"
87
+ else:
88
+ args.n_epochs = 120
89
+ args.base_lr = 7.5e-3
90
+ args.warmup_epochs = 5
91
+ args.warmup_lr = -1
92
+ args.width_mult_list = "0.65,0.8,1.0"
93
+ args.expand_list = "0.35"
94
+ args.depth_list = "0,1,2"
95
+ elif args.task == "expand":
96
+ if args.robust_mode :
97
+ args.path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase%d" % args.phase
98
+ else:
99
+ args.path = "exp/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase%d" % args.phase
100
+ args.dynamic_batch_size = 4
101
+ if args.phase == 1:
102
+ args.n_epochs = 25
103
+ args.base_lr = 2.5e-3
104
+ args.warmup_epochs = 0
105
+ args.warmup_lr = -1
106
+ args.width_mult_list = "0.65,0.8,1.0"
107
+ args.expand_list = "0.25,0.35"
108
+ args.depth_list = "0,1,2"
109
+ else:
110
+ args.n_epochs = 120
111
+ args.base_lr = 7.5e-3
112
+ args.warmup_epochs = 5
113
+ args.warmup_lr = -1
114
+ args.width_mult_list = "0.65,0.8,1.0"
115
+ args.expand_list = "0.2,0.25,0.35"
116
+ args.depth_list = "0,1,2"
117
+ else:
118
+ raise NotImplementedError
119
+ else:
120
+ args.width_mult_list = "1.0"
121
+ if args.task == "kernel":
122
+ if args.robust_mode:
123
+ args.path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel"
124
+ else:
125
+ args.path = "exp/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel"
126
+ args.dynamic_batch_size = 1
127
+ args.n_epochs = 120
128
+ args.base_lr = 3e-2
129
+ args.warmup_epochs = 5
130
+ args.warmup_lr = -1
131
+ args.ks_list = "3,5,7"
132
+ args.expand_list = "6"
133
+ args.depth_list = "4"
134
+ elif args.task == "depth":
135
+ if args.robust_mode :
136
+ args.path = "exp/robust/WPS/"+args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase%d" % args.phase
137
+ else:
138
+ args.path = "exp/WPS/"+args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase%d" % args.phase
139
+ args.dynamic_batch_size = 2
140
+ if args.phase == 1:
141
+ args.n_epochs = 25
142
+ args.base_lr = 2.5e-3
143
+ args.warmup_epochs = 0
144
+ args.warmup_lr = -1
145
+ args.ks_list = "3,5,7"
146
+ args.expand_list = "6"
147
+ args.depth_list = "3,4"
148
+ else:
149
+ args.n_epochs = 120
150
+ args.base_lr = 7.5e-3
151
+ args.warmup_epochs = 5
152
+ args.warmup_lr = -1
153
+ args.ks_list = "3,5,7"
154
+ args.expand_list = "6"
155
+ args.depth_list = "2,3,4"
156
+ elif args.task == "expand":
157
+ if args.robust_mode:
158
+ args.path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel_depth2kernel_depth_width/phase%d" % args.phase
159
+ else:
160
+ args.path = "exp/WPS/"+ args.dataset + '/' + args.model_name + '/' + args.train_criterion + "/kernel_depth2kernel_depth_width/phase%d" % args.phase
161
+ args.dynamic_batch_size = 4
162
+ if args.phase == 1:
163
+ args.n_epochs = 25
164
+ args.base_lr = 2.5e-3
165
+ args.warmup_epochs = 0
166
+ args.warmup_lr = -1
167
+ args.ks_list = "3,5,7"
168
+ args.expand_list = "4,6"
169
+ args.depth_list = "2,3,4"
170
+ else:
171
+ args.n_epochs = 120
172
+ args.base_lr = 7.5e-3
173
+ args.warmup_epochs = 5
174
+ args.warmup_lr = -1
175
+ args.ks_list = "3,5,7"
176
+ args.expand_list = "3,4,6"
177
+ args.depth_list = "2,3,4"
178
+ else:
179
+ raise NotImplementedError
180
+ args.manual_seed = 0
181
+
182
+ args.lr_schedule_type = "cosine"
183
+
184
+ args.base_batch_size = 64
185
+ args.valid_size = 64
186
+
187
+ args.opt_type = "sgd"
188
+ args.momentum = 0.9
189
+ args.no_nesterov = False
190
+ args.weight_decay = 3e-5
191
+ args.label_smoothing = 0.1
192
+ args.no_decay_keys = "bn#bias"
193
+ args.fp16_allreduce = False
194
+
195
+ args.model_init = "he_fout"
196
+ args.validation_frequency = 1
197
+ args.print_frequency = 10
198
+
199
+ args.n_worker = 8
200
+ args.resize_scale = 0.08
201
+ args.distort_color = "tf"
202
+ if args.dataset == "imagenet":
203
+ args.image_size = "128,160,192,224"
204
+ else:
205
+ args.image_size = "32"
206
+ args.continuous_size = True
207
+ args.not_sync_distributed_image_size = False
208
+
209
+ args.bn_momentum = 0.1
210
+ args.bn_eps = 1e-5
211
+ args.dropout = 0.1
212
+ args.base_stage_width = "google"
213
+
214
+
215
+ args.dy_conv_scaling_mode = -1
216
+ args.independent_distributed_sampling = False
217
+
218
+ args.kd_ratio = 1.0
219
+ args.kd_type = "ce"
220
+
221
+
222
+ if __name__ == "__main__":
223
+ os.makedirs(args.path, exist_ok=True)
224
+
225
+ # Initialize Horovod
226
+ hvd.init()
227
+ # Pin GPU to be used to process local rank (one GPU per process)
228
+ torch.cuda.set_device(hvd.local_rank())
229
+ if args.robust_mode:
230
+ args.teacher_path = 'exp/robust/teacher/' + args.dataset + '/' + args.model_name + '/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
231
+ else:
232
+ args.teacher_path = 'exp/teacher/' + args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
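+ # The teacher checkpoint is expected to have been produced by train_teacher_net.py under the matching exp/ directory.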
233
+ num_gpus = hvd.size()
234
+
235
+ torch.manual_seed(args.manual_seed)
236
+ torch.cuda.manual_seed_all(args.manual_seed)
237
+ np.random.seed(args.manual_seed)
238
+ random.seed(args.manual_seed)
239
+
240
+ # image size
241
+ args.image_size = [int(img_size) for img_size in args.image_size.split(",")]
242
+ if len(args.image_size) == 1:
243
+ args.image_size = args.image_size[0]
244
+ MyRandomResizedCrop.CONTINUOUS = args.continuous_size
245
+ MyRandomResizedCrop.SYNC_DISTRIBUTED = not args.not_sync_distributed_image_size
246
+
247
+ # build run config from args
248
+ args.lr_schedule_param = None
249
+ args.opt_param = {
250
+ "momentum": args.momentum,
251
+ "nesterov": not args.no_nesterov,
252
+ }
253
+ args.init_lr = args.base_lr * num_gpus # linearly rescale the learning rate
254
+ if args.warmup_lr < 0:
255
+ args.warmup_lr = args.base_lr
256
+ args.train_batch_size = args.base_batch_size
257
+ args.test_batch_size = args.base_batch_size * 4
258
+ run_config = DistributedClassificationRunConfig(
259
+ **args.__dict__, num_replicas=num_gpus, rank=hvd.rank()
260
+ )
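+ # Per-worker run config: all training settings above plus the data provider, keyed by Horovod rank.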
261
+
262
+ # print run config information
263
+ if hvd.rank() == 0:
264
+ print("Run config:")
265
+ for k, v in run_config.config.items():
266
+ print("\t%s: %s" % (k, v))
267
+
268
+ if args.dy_conv_scaling_mode == -1:
269
+ args.dy_conv_scaling_mode = None
270
+ DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = args.dy_conv_scaling_mode
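+ # -1 is a sentinel for "no kernel transform": it is mapped to None before being set globally on DynamicSeparableConv2d.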
271
+
272
+ # build net from args
273
+ args.width_mult_list = [
274
+ float(width_mult) for width_mult in args.width_mult_list.split(",")
275
+ ]
276
+ args.ks_list = [int(ks) for ks in args.ks_list.split(",")]
277
+ if args.model_name == "ResNet50":
278
+ args.expand_list = [float(e) for e in args.expand_list.split(",")]
279
+ else:
280
+ args.expand_list = [int(e) for e in args.expand_list.split(",")]
281
+ args.depth_list = [int(d) for d in args.depth_list.split(",")]
282
+
283
+ args.width_mult_list = (
284
+ args.width_mult_list[0]
285
+ if len(args.width_mult_list) == 1
286
+ else args.width_mult_list
287
+ )
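+ # A single width multiplier is collapsed to a scalar; otherwise the full list is kept for elastic width.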
288
+
289
+ if args.model_name == "ResNet50":
290
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
291
+ net = DYNResNets_Cifar( n_classes=run_config.data_provider.n_classes,
292
+ bn_param=(args.bn_momentum, args.bn_eps),
293
+ dropout_rate=args.dropout,
294
+ depth_list=args.depth_list,
295
+ expand_ratio_list=args.expand_list,
296
+ width_mult_list=args.width_mult_list,)
297
+ else:
298
+ net = DYNResNets( n_classes=run_config.data_provider.n_classes,
299
+ bn_param=(args.bn_momentum, args.bn_eps),
300
+ dropout_rate=args.dropout,
301
+ depth_list=args.depth_list,
302
+ expand_ratio_list=args.expand_list,
303
+ width_mult_list=args.width_mult_list,)
304
+ elif args.model_name == "MBV3":
305
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
306
+ net = DYNMobileNetV3_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
307
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
308
+ else:
309
+ net = DYNMobileNetV3(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
310
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
311
+ elif args.model_name == "ProxylessNASNet":
312
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
313
+ net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
314
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
315
+ else:
316
+ net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
317
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list)
318
+ elif args.model_name == "MBV2":
319
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
320
+ net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
321
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list,base_stage_width=args.base_stage_width)
322
+ else:
323
+ net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
324
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list,width_mult=args.width_mult_list,base_stage_width=args.base_stage_width)
325
+ else:
326
+ raise NotImplementedError
327
+ # teacher model
328
+ if args.kd_ratio > 0:
329
+
330
+ if args.model_name =="ResNet50":
331
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
332
+ args.teacher_model = DYNResNets_Cifar(
333
+ n_classes=run_config.data_provider.n_classes,
334
+ bn_param=(args.bn_momentum, args.bn_eps),
335
+ dropout_rate=args.dropout,
336
+ depth_list=[2],
337
+ expand_ratio_list=[0.35],
338
+ width_mult_list=[1.0],
339
+ )
340
+ else:
341
+ args.teacher_model = DYNResNets(
342
+ n_classes=run_config.data_provider.n_classes,
343
+ bn_param=(args.bn_momentum, args.bn_eps),
344
+ dropout_rate=args.dropout,
345
+ depth_list=[2],
346
+ expand_ratio_list=[0.35],
347
+ width_mult_list=[1.0],
348
+ )
349
+ elif args.model_name =="MBV3":
350
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
351
+ args.teacher_model = DYNMobileNetV3_Cifar(
352
+ n_classes=run_config.data_provider.n_classes,
353
+ bn_param=(args.bn_momentum, args.bn_eps),
354
+ dropout_rate=0,
355
+ width_mult=1.0,
356
+ ks_list=[7],
357
+ expand_ratio_list=[6],
358
+ depth_list=[4]
359
+ )
360
+ else:
361
+ args.teacher_model = DYNMobileNetV3(
362
+ n_classes=run_config.data_provider.n_classes,
363
+ bn_param=(args.bn_momentum, args.bn_eps),
364
+ dropout_rate=0,
365
+ width_mult=1.0,
366
+ ks_list=[7],
367
+ expand_ratio_list=[6],
368
+ depth_list=[4]
369
+ )
370
+ elif args.model_name == "ProxylessNASNet":
371
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
372
+ args.teacher_model = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,
373
+ bn_param=(args.bn_momentum, args.bn_eps),
374
+ dropout_rate=0,
375
+ width_mult=1.0,
376
+ ks_list=[7],
377
+ expand_ratio_list=[6],
378
+ depth_list=[4])
379
+ else:
380
+ args.teacher_model = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,
381
+ bn_param=(args.bn_momentum, args.bn_eps),
382
+ dropout_rate=0,
383
+ width_mult=1.0,
384
+ ks_list=[7],
385
+ expand_ratio_list=[6],
386
+ depth_list=[4])
387
+ elif args.model_name == "MBV2":
388
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
389
+ args.teacher_model = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,
390
+ bn_param=(args.bn_momentum, args.bn_eps),
391
+ dropout_rate=0,
392
+ width_mult=1.0,
393
+ ks_list=[7],
394
+ expand_ratio_list=[6],
395
+ depth_list=[4],base_stage_width=args.base_stage_width)
396
+ else:
397
+ args.teacher_model = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,
398
+ bn_param=(args.bn_momentum, args.bn_eps),
399
+ dropout_rate=0,
400
+ width_mult=1.0,
401
+ ks_list=[7],
402
+ expand_ratio_list=[6],
403
+ depth_list=[4],base_stage_width=args.base_stage_width)
404
+
405
+ args.teacher_model.cuda()
406
+
407
+ """ Distributed RunManager """
408
+ # Horovod: (optional) compression algorithm.
409
+ compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
410
+ distributed_run_manager = DistributedRunManager(
411
+ args.path,
412
+ net,
413
+ run_config,
414
+ compression,
415
+ backward_steps=args.dynamic_batch_size,
416
+ is_root=(hvd.rank() == 0),
417
+ )
418
+ distributed_run_manager.save_config()
419
+ # hvd broadcast
420
+ distributed_run_manager.broadcast()
421
+
422
+ # load teacher net weights
423
+ if args.kd_ratio > 0:
424
+ load_models(
425
+ distributed_run_manager, args.teacher_model, model_path=args.teacher_path
426
+ )
427
+
428
+ # training
429
+ from proard.classification.elastic_nn.training.progressive_shrinking import (
430
+ validate,
431
+ train,
432
+ )
433
+ if args.model_name =="ResNet50":
434
+ validate_func_dict = {
435
+ "image_size_list": {224 if args.dataset == "imagenet" else 32}
436
+ if isinstance(args.image_size, int)
437
+ else sorted({160, 224}),
438
+ "width_mult_list": sorted({min(args.width_mult_list), max(args.width_mult_list)}),
439
+ "expand_ratio_list": sorted({min(args.expand_list), max(args.expand_list)}),
440
+ "depth_list": sorted({min(net.depth_list), max(net.depth_list)}),
441
+ }
442
+ else:
443
+ validate_func_dict = {
444
+ "image_size_list": {224 if args.dataset == "imagenet" else 32}
445
+ if isinstance(args.image_size, int)
446
+ else sorted({160, 224}),
447
+ "width_mult_list": [1.0],
448
+ "ks_list": sorted({min(args.ks_list), max(args.ks_list)}),
449
+ "expand_ratio_list": sorted({min(args.expand_list), max(args.expand_list)}),
450
+ "depth_list": sorted({min(net.depth_list), max(net.depth_list)}),
451
+ }
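+ # Validation only samples the extreme (min/max) settings of each elastic dimension rather than the full search space.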
452
+
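+ # In this WPS variant every task below calls train() directly (the train_elastic_* helpers are imported but unused), so all elastic dimensions are optimized jointly rather than progressively.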
453
+ if args.task == "width":
454
+ from proard.classification.elastic_nn.training.progressive_shrinking import (
455
+ train_elastic_width_mult,
456
+ )
457
+ if distributed_run_manager.start_epoch == 0:
458
+ if args.robust_mode:
459
+ args.dyn_checkpoint_path ='exp/robust/teacher/' +args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
460
+ else:
461
+ args.dyn_checkpoint_path ='exp/teacher/' +args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
462
+ load_models(
463
+ distributed_run_manager,
464
+ distributed_run_manager.net,
465
+ args.dyn_checkpoint_path,
466
+ )
467
+ distributed_run_manager.write_log(
468
+ "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
469
+ % validate(distributed_run_manager, is_test=True, **validate_func_dict),
470
+ "valid",
471
+ )
472
+ else:
473
+ assert args.resume
474
+ train(distributed_run_manager,args,lambda _run_manager, epoch, is_test: validate(
475
+ _run_manager, epoch, is_test, **validate_func_dict
476
+ ),)
477
+
478
+
479
+
480
+ elif args.task == "kernel":
481
+ validate_func_dict["ks_list"] = sorted(args.ks_list)
482
+ if distributed_run_manager.start_epoch == 0:
483
+ if args.robust_mode:
484
+ args.dyn_checkpoint_path ='exp/robust/teacher/' + args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
485
+ else:
486
+ args.dyn_checkpoint_path ='exp/teacher/' + args.dataset + '/' + args.model_name +'/' + args.train_criterion + "/checkpoint/model_best.pth.tar"
487
+ load_models(
488
+ distributed_run_manager,
489
+ distributed_run_manager.net,
490
+ args.dyn_checkpoint_path,
491
+ )
492
+ distributed_run_manager.write_log(
493
+ "%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%s"
494
+ % validate(distributed_run_manager, is_test=True, **validate_func_dict),
495
+ "valid",
496
+ )
497
+ else:
498
+ assert args.resume
499
+ train(
500
+ distributed_run_manager,
501
+ args,
502
+ lambda _run_manager, epoch, is_test: validate(
503
+ _run_manager, epoch, is_test, **validate_func_dict
504
+ ),
505
+ )
506
+ elif args.task == "depth":
507
+ from proard.classification.elastic_nn.training.progressive_shrinking import (
508
+ train_elastic_depth,
509
+ )
510
+ if args.robust_mode:
511
+ if args.model_name =="ResNet50":
512
+ if args.phase == 1:
513
+ args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/normal2width" +"/checkpoint/model_best.pth.tar"
514
+ else:
515
+ args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase1" + "/checkpoint/model_best.pth.tar"
516
+ else:
517
+ if args.phase == 1:
518
+ args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel" +"/checkpoint/model_best.pth.tar"
519
+ else:
520
+ args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase1" + "/checkpoint/model_best.pth.tar"
521
+ else :
522
+ if args.model_name =="ResNet50":
523
+ if args.phase == 1:
524
+ args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/normal2width" +"/checkpoint/model_best.pth.tar"
525
+ else:
526
+ args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/width2width_depth/phase1" + "/checkpoint/model_best.pth.tar"
527
+ else:
528
+ if args.phase == 1:
529
+ args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/normal2kernel" +"/checkpoint/model_best.pth.tar"
530
+ else:
531
+ args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/' + args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase1" + "/checkpoint/model_best.pth.tar"
532
+ train(
533
+ distributed_run_manager,
534
+ args,
535
+ lambda _run_manager, epoch, is_test: validate(
536
+ _run_manager, epoch, is_test, **validate_func_dict
537
+ ),)
538
+ elif args.task == "expand":
539
+ from proard.classification.elastic_nn.training.progressive_shrinking import (
540
+ train_elastic_expand,
541
+ )
542
+ if args.robust_mode :
543
+ if args.model_name =="ResNet50":
544
+ if args.phase == 1:
545
+ args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width2width_depth/phase2" + "/checkpoint/model_best.pth.tar"
546
+ else:
547
+ args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
548
+ else:
549
+ if args.phase == 1:
550
+ args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase2" + "/checkpoint/model_best.pth.tar"
551
+ else:
552
+ args.dyn_checkpoint_path = "exp/robust/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel_depth2kernel_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
553
+ else:
554
+ if args.model_name =="ResNet50":
555
+ if args.phase == 1:
556
+ args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width2width_depth/phase2" + "/checkpoint/model_best.pth.tar"
557
+ else:
558
+ args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/width_depth2width_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
559
+ else:
560
+ if args.phase == 1:
561
+ args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel2kernel_depth/phase2" + "/checkpoint/model_best.pth.tar"
562
+ else:
563
+ args.dyn_checkpoint_path = "exp/WPS/"+ args.dataset + '/'+ args.model_name +'/' + args.train_criterion +"/kernel_depth2kernel_depth_width/phase1" + "/checkpoint/model_best.pth.tar"
564
+
565
+ train(
566
+ distributed_run_manager,
567
+ args,
568
+ lambda _run_manager, epoch, is_test: validate(
569
+ _run_manager, epoch, is_test, **validate_func_dict
570
+ ),)
571
+ else:
572
+ raise NotImplementedError
train_teacher_net.py ADDED
@@ -0,0 +1,216 @@
1
+ # Once for All: Train One Network and Specialize it for Efficient Deployment
2
+ # Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
3
+ # International Conference on Learning Representations (ICLR), 2020.
4
+
5
+ import argparse
6
+ import numpy as np
7
+ import os
8
+ import random
9
+ # using for distributed training
10
+ import horovod.torch as hvd
11
+ import torch
12
+
13
+ from proard.classification.elastic_nn.modules.dynamic_op import (
14
+ DynamicSeparableConv2d,
15
+ )
16
+ from proard.classification.elastic_nn.networks import DYNResNets,DYNMobileNetV3,DYNProxylessNASNets,DYNMobileNetV3_Cifar,DYNResNets_Cifar,DYNProxylessNASNets_Cifar
17
+ from proard.classification.run_manager import DistributedClassificationRunConfig
18
+ from proard.classification.networks import WideResNet
19
+ from proard.classification.run_manager import DistributedRunManager
20
+
21
+
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--model_name", type=str, default="MBV2", choices=["ResNet50", "MBV3", "ProxylessNASNet","WideResNet","MBV2"])
24
+ parser.add_argument("--teacher_model_name", type=str, default="WideResNet", choices=["WideResNet"])
25
+ parser.add_argument("--dataset", type=str, default="cifar100", choices=["cifar10", "cifar100", "imagenet"])
26
+ parser.add_argument("--robust_mode", type=bool, default=True)
27
+ parser.add_argument("--epsilon", type=float, default=0.031)
28
+ parser.add_argument("--num_steps", type=int, default=10)
29
+ parser.add_argument("--step_size", type=float, default=0.0078)
30
+ parser.add_argument("--clip_min", type=int, default=0)
31
+ parser.add_argument("--clip_max", type=int, default=1)
32
+ parser.add_argument("--const_init", type=bool, default=False)
33
+ parser.add_argument("--beta", type=float, default=6.0)
34
+ parser.add_argument("--distance", type=str, default="l_inf",choices=["l_inf","l2"])
35
+ parser.add_argument("--train_criterion", type=str, default="trades",choices=["trades","sat","mart","hat"])
36
+ parser.add_argument("--test_criterion", type=str, default="ce",choices=["ce"])
37
+ parser.add_argument("--kd_criterion", type=str, default="rslad",choices=["ard","rslad","adaad"])
38
+ parser.add_argument("--attack_type", type=str, default="linf-pgd",choices=['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd','squar_attack','autoattack','apgd_ce'])
39
+
40
+
41
+ args = parser.parse_args()
42
+ if args.robust_mode:
43
+ args.path = 'exp/robust/teacher/' + args.dataset + "/" + args.model_name + '/' + args.train_criterion
44
+ else:
45
+ args.path = 'exp/teacher/' + args.dataset + "/" + args.model_name
46
+ args.n_epochs = 120
47
+ args.base_lr = 0.1
48
+ args.warmup_epochs = 5
49
+ args.warmup_lr = -1
50
+ args.manual_seed = 0
51
+ args.lr_schedule_type = "cosine"
52
+ args.base_batch_size = 128
53
+ args.valid_size = None
54
+ args.opt_type = "sgd"
55
+ args.momentum = 0.9
56
+ args.no_nesterov = False
57
+ args.weight_decay = 2e-4
58
+ args.label_smoothing = 0.0
59
+ args.no_decay_keys = "bn#bias"
60
+ args.fp16_allreduce = False
61
+ args.model_init = "he_fout"
62
+ args.validation_frequency = 1
63
+ args.print_frequency = 10
64
+ args.n_worker = 32
65
+ if args.dataset =="imagenet":
66
+ args.image_size = "224"
67
+ else:
68
+ args.image_size = "32"
69
+ args.continuous_size = True
70
+ args.not_sync_distributed_image_size = False
71
+ args.bn_momentum = 0.1
72
+ args.bn_eps = 1e-5
73
+ args.dropout = 0.0
74
+ args.base_stage_width = "google"
75
+ ###### Parameters for MBV3, ProxylessNet, and MBV2
76
+ if args.model_name != "ResNet50":
77
+ args.ks_list = '7'
78
+ args.expand_list = '6'
79
+ args.depth_list = '4'
80
+ args.width_mult_list = "1.0"
81
+ else:
82
+ ###### Parameters for ResNet50
83
+ args.ks_list = "3"
84
+ args.expand_list = "0.35"
85
+ args.depth_list = "2"
86
+ args.width_mult_list = "1.0"
87
+ ########################################
88
+ args.dy_conv_scaling_mode = 1
89
+ args.independent_distributed_sampling = False
90
+ args.kd_ratio = 0.0
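+ # kd_ratio = 0.0: no distillation here; this script trains the teacher network itself.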
91
+ args.kd_type = "ce"
92
+ args.dynamic_batch_size = 1
93
+ args.num_gpus = 4
94
+ if __name__ == "__main__":
95
+ os.makedirs(args.path, exist_ok=True)
96
+
97
+ # Initialize Horovod
98
+ hvd.init()
99
+ # Pin GPU to be used to process local rank (one GPU per process)
100
+ torch.cuda.set_device(hvd.local_rank())
101
+
102
+ num_gpus = hvd.size()
103
+ torch.manual_seed(args.manual_seed)
104
+ torch.cuda.manual_seed_all(args.manual_seed)
105
+ np.random.seed(args.manual_seed)
106
+ random.seed(args.manual_seed)
107
+
108
+ # image size
109
+ args.image_size = [int(img_size) for img_size in args.image_size.split(",")]
110
+ if len(args.image_size) == 1:
111
+ args.image_size = args.image_size[0]
112
+
113
+ # build run config from args
114
+ args.lr_schedule_param = None
115
+ args.opt_param = {
116
+ "momentum": args.momentum,
117
+ "nesterov": not args.no_nesterov,
118
+ }
119
+ args.init_lr = args.base_lr * num_gpus # linearly rescale the learning rate
120
+ if args.warmup_lr < 0:
121
+ args.warmup_lr = args.base_lr
122
+ args.train_batch_size = args.base_batch_size
123
+ args.test_batch_size = args.base_batch_size
124
+ print(args.__dict__)
125
+ run_config = DistributedClassificationRunConfig(
126
+ **args.__dict__,num_replicas=num_gpus, rank=hvd.rank()
127
+ )
128
+
129
+ # print run config information
130
+ if hvd.rank() == 0:
131
+ print("Run config:")
132
+ for k, v in run_config.config.items():
133
+ print("\t%s: %s" % (k, v))
134
+
135
+ if args.dy_conv_scaling_mode == -1:
136
+ args.dy_conv_scaling_mode = None
137
+ DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = args.dy_conv_scaling_mode
138
+
139
+ # build net from args
140
+ args.width_mult_list = [
141
+ float(width_mult) for width_mult in args.width_mult_list.split(",")
142
+ ]
143
+ args.ks_list = [int(ks) for ks in args.ks_list.split(",")]
144
+ args.expand_list = [float(e) for e in args.expand_list.split(",")]
145
+ args.depth_list = [int(d) for d in args.depth_list.split(",")]
146
+
147
+ args.width_mult_list = (
148
+ args.width_mult_list[0]
149
+ if len(args.width_mult_list) == 1
150
+ else args.width_mult_list
151
+ )
152
+ if args.model_name == "ResNet50":
153
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
154
+ # net = ResNet50_Cifar(n_classes=run_config.data_provider.n_classes)
155
+ net = DYNResNets_Cifar( n_classes=run_config.data_provider.n_classes,
156
+ bn_param=(args.bn_momentum, args.bn_eps),
157
+ dropout_rate=args.dropout,
158
+ depth_list=args.depth_list,
159
+ expand_ratio_list=args.expand_list,
160
+ width_mult_list=args.width_mult_list,)
161
+ else:
162
+ net = DYNResNets( n_classes=run_config.data_provider.n_classes,
163
+ bn_param=(args.bn_momentum, args.bn_eps),
164
+ dropout_rate=args.dropout,
165
+ depth_list=args.depth_list,
166
+ expand_ratio_list=args.expand_list,
167
+ width_mult_list=args.width_mult_list,)
168
+ elif args.model_name == "MBV3":
169
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
170
+ net = DYNMobileNetV3_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
171
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list)
172
+ else:
173
+ net = DYNMobileNetV3(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
174
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list)
175
+ elif args.model_name == "ProxylessNASNet":
176
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
177
+ net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
178
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list)
179
+ else:
180
+ net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),
181
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list)
182
+
183
+ elif args.model_name == "MBV2":
184
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
185
+ net = DYNProxylessNASNets_Cifar(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),base_stage_width=args.base_stage_width,
186
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list)
187
+ else:
188
+ net = DYNProxylessNASNets(n_classes=run_config.data_provider.n_classes,bn_param=(args.bn_momentum,args.bn_eps),base_stage_width=args.base_stage_width,
189
+ dropout_rate= args.dropout, ks_list=args.ks_list , expand_ratio_list= args.expand_list , depth_list= args.depth_list)
190
+ else:
191
+ raise NotImplementedError
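+ # teacher_model_name selects the network that is actually trained; with the default "WideResNet" it replaces the dynamic net built above.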
192
+ if args.teacher_model_name == "WideResNet":
193
+ if args.dataset == "cifar10" or args.dataset == "cifar100":
194
+ net = WideResNet(num_classes=run_config.data_provider.n_classes)
195
+ else:
196
+ raise NotImplementedError
197
+ else:
198
+ raise NotImplementedError
199
+ args.teacher_model = None #'exp/teacher/' + args.dataset + "/" + "WideResNet"
200
+
201
+ """ Distributed RunManager """
202
+ #Horovod: (optional) compression algorithm.
203
+ compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
204
+ distributed_run_manager = DistributedRunManager(
205
+ args.path,
206
+ net,
207
+ run_config,
208
+ compression,
209
+ backward_steps=args.dynamic_batch_size,
210
+ is_root=(hvd.rank() == 0),
211
+ )
212
+ distributed_run_manager.save_config()
213
+ distributed_run_manager.broadcast()
214
+
215
+
216
+ distributed_run_manager.train(args)