andreysher committed
Commit 4d679c2
1 Parent(s): 632e862

imagenet-benchmark

MobileNetV2/MobileNetV2-ENOT.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fae5b0822282cce7cec83d63b96af7bd12deae8e8371083b28a9bc6002e08a7d
+ size 10682115
MobileNetV2/MobileNetV2-ENOT.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0d39fa80cba1d431eea3009c7ae0bf506fb7e6c6c97853994329ebc03a1fc40e
+ size 32641690
README.md CHANGED
@@ -1,3 +1,61 @@
  ---
  license: apache-2.0
+ datasets:
+ - imagenet-1k
+ library_name: torchvision
+ pipeline_tag: image-classification
+ tags:
+ - onnx
+ - ENOT-AutoDL
  ---
+
+ # ENOT-AutoDL pruning benchmark on ImageNet-1k
+
+ This repository contains models accelerated with the [ENOT-AutoDL](https://pypi.org/project/enot-autodl/) framework.
+ Models from [Torchvision](https://pytorch.org/vision/stable/models.html) are used as baselines.
+ The evaluation code is likewise based on the Torchvision references.
+
+ ## ResNet-50
+
+ | Model                     | Latency (MMACs)   | Accuracy (%)    |
+ | ------------------------- | :---------------: | :-------------: |
+ | **ResNet-50 Torchvision** | 4144.854          | 76.144          |
+ | **ResNet-50 ENOT (x2)**   | 2057.615 (x2.014) | 75.482 (-0.662) |
+ | **ResNet-50 ENOT (x4)**   | 867.943 (x4.775)  | 73.576 (-2.568) |
+
+ ## ViT-B/32
+
+ | Model                    | Latency (MMACs)  | Accuracy (%)    |
+ | ------------------------ | :--------------: | :-------------: |
+ | **ViT-B/32 Torchvision** | 4413.986         | 75.912          |
+ | **ViT-B/32 ENOT**        | 492.232 (x8.967) | 73.718 (-2.194) |
+
+ ## MobileNetV2
+
+ | Model                       | Latency (MMACs)  | Accuracy (%)   |
+ | --------------------------- | :--------------: | :------------: |
+ | **MobileNetV2 Torchvision** | 334.227          | 71.878         |
+ | **MobileNetV2 ENOT**        | 156.800 (x2.131) | 69.898 (-1.98) |
+
+ # Validation
+
+ To validate the results, follow these steps:
+
+ 1. Install all required packages:
+    ```bash
+    pip install -r requirements.txt
+    ```
+ 1. Calculate model latency:
+    ```bash
+    python measure_mac.py --model-ckpt path/to/model.pth
+    ```
+ 1. Measure the accuracy of an ONNX model:
+    ```bash
+    python test.py --data-path path/to/imagenet --model-onnx path/to/model.onnx --batch-size 1
+    ```
+ 1. Measure the accuracy of a PyTorch (.pth) model:
+    ```bash
+    python test.py --data-path path/to/imagenet --model-ckpt path/to/model.pth
+    ```
+
+ If you want to book a demo, please [contact us](mailto:enot@enot.ai).
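
Both `measure_mac.py` and `test.py` below read the entire pruned `nn.Module` from the checkpoint's `"model_ckpt"` key rather than a plain `state_dict`. A minimal sketch of loading one of the pruned checkpoints directly, assuming that layout:

```python
# Minimal sketch: load a pruned checkpoint from this repository.
# Assumes the .pth stores the full nn.Module under the "model_ckpt" key,
# which is how measure_mac.py and test.py in this commit read it.
# NOTE: unpickling a full module requires the classes used at save time to be
# importable (e.g. the pinned torchvision and, possibly, ENOT packages).
import torch

checkpoint = torch.load("MobileNetV2/MobileNetV2-ENOT.pth", map_location="cpu")
model = checkpoint["model_ckpt"]  # full module, not just weights
model.eval()
```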
ResNet-50/ResNet50-ENOT-x2.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9f689ec182909427df72d390d425eb3b72618d4c40ae089b93b66ea14c6adf5f
+ size 50666788
ResNet-50/ResNet50-ENOT-x2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c0b29e2ac563332d02274d6d656379d3b0957b91b7c8b6c1b4433657d74d6e68
+ size 101839301
ResNet-50/ResNet50-ENOT-x4.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:387b705b1d83c844f513d7646f95138f8fcfb420e1ef0b5f8d7039e550c66b91
+ size 20850032
ResNet-50/ResNet50-ENOT-x4.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7a9b7d6ac9062b92da9b44f61ace8c62da76ce86fb0947fdb40fb449792e194a
+ size 62177349
ViT-B-32/ViT-B-32-ENOT.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:63a4d0e19cfbeca9dca0b18aaf5c60b2d845c05a25e7641b954e90839efda63b
+ size 39430730
ViT-B-32/ViT-B-32-ENOT.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:92a81cef913af4012215215400049317168939624f09d76e9043aee2342af356
+ size 157444613
measure_mac.py ADDED
@@ -0,0 +1,28 @@
+ import argparse
+
+ import torch
+ from fvcore.nn import FlopCountAnalysis
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model-ckpt", type=str, help="path of .pth checkpoint")
+
+     return parser.parse_args()
+
+
+ def main():
+     args = get_args()
+
+     # The checkpoint stores the entire nn.Module under the "model_ckpt" key.
+     checkpoint = torch.load(args.model_ckpt, map_location="cpu")
+     model = checkpoint["model_ckpt"]
+     model.eval()
+
+     # fvcore counts one multiply-add as one "flop", so the total is in MACs.
+     flops = FlopCountAnalysis(model.cpu(), torch.ones((1, 3, 224, 224)))
+     flops = flops.total()
+
+     print(f"MMACs = {flops / 1e6}")
+
+
+ if __name__ == "__main__":
+     main()
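
For a quick sanity check against the tables above, the same measurement can be run on a Torchvision baseline without downloading a checkpoint (a sketch; the string `weights` argument assumes the pinned torchvision 0.14):

```python
# Sketch: measure the baseline MobileNetV2 from torchvision the same way.
import torch
import torchvision
from fvcore.nn import FlopCountAnalysis

model = torchvision.models.mobilenet_v2(weights="IMAGENET1K_V1").eval()
# fvcore counts one multiply-add as one "flop", i.e. the total is in MACs.
macs = FlopCountAnalysis(model, torch.ones((1, 3, 224, 224))).total()
print(f"MMACs = {macs / 1e6}")  # should be close to the 334.227 reported above
```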
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch==1.13.1
+ torchvision==0.14.1
+ fvcore==0.1.5.post20221221
+ onnxruntime-gpu==1.15.1
+ onnx==1.13.1
test.py ADDED
@@ -0,0 +1,180 @@
+ import os
+
+ import onnxruntime
+ import torch
+ import torch.utils.data
+ import torchvision
+ from torch import nn
+ from torchvision.transforms.functional import InterpolationMode
+
+ import utils
+
+
+ def evaluate(
+     criterion,
+     data_loader,
+     device,
+     model=None,
+     model_onnx_path=None,
+     print_freq=100,
+     log_suffix="",
+ ):
+     if model is None and not model_onnx_path:
+         raise ValueError("either model or model_onnx_path must be provided")
+
+     if model_onnx_path:
+         session = onnxruntime.InferenceSession(
+             model_onnx_path, providers=["CPUExecutionProvider"]
+         )
+         input_name = session.get_inputs()[0].name
+
+     metric_logger = utils.MetricLogger(delimiter=" ")
+     header = f"Test: {log_suffix}"
+
+     num_processed_samples = 0
+     with torch.inference_mode():
+         for image, target in metric_logger.log_every(data_loader, print_freq, header):
+             target = target.to(device, non_blocking=True)
+             image = image.to(device)
+
+             if model_onnx_path:
+                 # from torch to numpy (ort)
+                 input_data = image.cpu().numpy()
+
+                 output_data = session.run([], {input_name: input_data})[0]
+
+                 # from numpy to torch
+                 output = torch.from_numpy(output_data).to(device)
+             elif model:
+                 output = model(image)
+
+             loss = criterion(output, target)
+
+             acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
+             # FIXME need to take into account that the datasets
+             # could have been padded in distributed setup
+             batch_size = image.shape[0]
+             metric_logger.update(loss=loss.item())
+             metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
+             metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+             num_processed_samples += batch_size
+
+     # gather the stats from all processes
+     metric_logger.synchronize_between_processes()
+
+     print(
+         f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}"
+     )
+     return metric_logger.acc1.global_avg
+
+
+ def load_data(valdir):
+     # Data loading code
+     print("Loading data")
+     interpolation = InterpolationMode("bilinear")
+
+     preprocessing = torchvision.transforms.Compose(
+         [
+             torchvision.transforms.Resize(256, interpolation=interpolation),
+             torchvision.transforms.CenterCrop(224),
+             torchvision.transforms.PILToTensor(),
+             torchvision.transforms.ConvertImageDtype(torch.float),
+             torchvision.transforms.Normalize(
+                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
+             ),
+         ]
+     )
+
+     dataset_test = torchvision.datasets.ImageFolder(
+         valdir,
+         preprocessing,
+     )
+
+     print("Creating data loaders")
+     test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+
+     return dataset_test, test_sampler
+
+
+ def main(args):
+     print(args)
+
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+     else:
+         device = torch.device("cpu")
+
+     val_dir = os.path.join(args.data_path, "val")
+     dataset_test, test_sampler = load_data(val_dir)
+
+     data_loader_test = torch.utils.data.DataLoader(
+         dataset_test,
+         batch_size=args.batch_size,
+         sampler=test_sampler,
+         num_workers=args.workers,
+         pin_memory=True,
+     )
+
+     print("Creating model")
+
+     criterion = nn.CrossEntropyLoss()
+
+     model = None
+     if args.model_ckpt:
+         # The checkpoint stores the full nn.Module under the "model_ckpt" key;
+         # if EMA weights are present, prefer them.
+         checkpoint = torch.load(args.model_ckpt, map_location="cpu")
+         model = checkpoint["model_ckpt"]
+         if "model_ema" in checkpoint:
+             state_dict = {}
+             for key, value in checkpoint["model_ema"].items():
+                 if "module." not in key:
+                     continue
+                 state_dict[key.replace("module.", "")] = value
+             model.load_state_dict(state_dict)
+         model = model.to(device)
+
+     accuracy = evaluate(
+         model=model,
+         model_onnx_path=args.model_onnx,
+         criterion=criterion,
+         data_loader=data_loader_test,
+         device=device,
+         print_freq=args.print_freq,
+     )
+     print(f"Model accuracy is: {accuracy}")
+
+
+ def get_args_parser(add_help=True):
+     import argparse
+
+     parser = argparse.ArgumentParser(
+         description="PyTorch Classification Evaluation", add_help=add_help
+     )
+
+     parser.add_argument(
+         "--data-path", default="datasets/imagenet", type=str, help="dataset path"
+     )
+     parser.add_argument(
+         "-b",
+         "--batch-size",
+         default=32,
+         type=int,
+         help="images per gpu, the total batch size is $NGPU x batch_size",
+     )
+     parser.add_argument(
+         "-j",
+         "--workers",
+         default=16,
+         type=int,
+         metavar="N",
+         help="number of data loading workers (default: 16)",
+     )
+     parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
+     parser.add_argument(
+         "--model-onnx", default="", type=str, help="path of .onnx checkpoint"
+     )
+     parser.add_argument(
+         "--model-ckpt", default="", type=str, help="path of .pth checkpoint"
+     )
+
+     return parser
+
+
+ if __name__ == "__main__":
+     args = get_args_parser().parse_args()
+     main(args)
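
Before running the full ImageNet evaluation, the exported ONNX models can be smoke-tested on random input (a sketch; the 1x3x224x224 shape matches the preprocessing in `load_data` above):

```python
# Sketch: verify an ONNX model loads and produces ImageNet-1k logits.
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession(
    "MobileNetV2/MobileNetV2-ENOT.onnx", providers=["CPUExecutionProvider"]
)
input_name = session.get_inputs()[0].name
dummy = np.random.rand(1, 3, 224, 224).astype(np.float32)
logits = session.run([], {input_name: dummy})[0]
print(logits.shape)  # expected: (1, 1000)
```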
utils.py ADDED
@@ -0,0 +1,208 @@
+ import datetime
+ import time
+ from collections import defaultdict
+ from collections import deque
+
+ import torch
+ import torch.distributed as dist
+
+
+ class SmoothedValue:
+     """Track a series of values and provide access to smoothed values over a
+     window or the global series average."""
+
+     def __init__(self, window_size=20, fmt=None):
+         if fmt is None:
+             fmt = "{median:.4f} ({global_avg:.4f})"
+         self.deque = deque(maxlen=window_size)
+         self.total = 0.0
+         self.count = 0
+         self.fmt = fmt
+
+     def update(self, value, n=1):
+         self.deque.append(value)
+         self.count += n
+         self.total += value * n
+
+     def synchronize_between_processes(self):
+         """
+         Warning: does not synchronize the deque!
+         """
+         t = reduce_across_processes([self.count, self.total])
+         t = t.tolist()
+         self.count = int(t[0])
+         self.total = t[1]
+
+     @property
+     def median(self):
+         d = torch.tensor(list(self.deque))
+         return d.median().item()
+
+     @property
+     def avg(self):
+         d = torch.tensor(list(self.deque), dtype=torch.float32)
+         return d.mean().item()
+
+     @property
+     def global_avg(self):
+         return self.total / self.count
+
+     @property
+     def max(self):
+         return max(self.deque)
+
+     @property
+     def value(self):
+         return self.deque[-1]
+
+     def __str__(self):
+         return self.fmt.format(
+             median=self.median,
+             avg=self.avg,
+             global_avg=self.global_avg,
+             max=self.max,
+             value=self.value,
+         )
+
+
+ class MetricLogger:
+     def __init__(self, delimiter="\t"):
+         self.meters = defaultdict(SmoothedValue)
+         self.delimiter = delimiter
+
+     def update(self, **kwargs):
+         for k, v in kwargs.items():
+             if isinstance(v, torch.Tensor):
+                 v = v.item()
+             assert isinstance(v, (float, int))
+             self.meters[k].update(v)
+
+     def __getattr__(self, attr):
+         if attr in self.meters:
+             return self.meters[attr]
+         if attr in self.__dict__:
+             return self.__dict__[attr]
+         raise AttributeError(
+             f"'{type(self).__name__}' object has no attribute '{attr}'"
+         )
+
+     def __str__(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append(f"{name}: {str(meter)}")
+         return self.delimiter.join(loss_str)
+
+     def synchronize_between_processes(self):
+         for meter in self.meters.values():
+             meter.synchronize_between_processes()
+
+     def add_meter(self, name, meter):
+         self.meters[name] = meter
+
+     def log_every(self, iterable, print_freq, header=None):
+         i = 0
+         if not header:
+             header = ""
+         start_time = time.time()
+         end = time.time()
+         iter_time = SmoothedValue(fmt="{avg:.4f}")
+         data_time = SmoothedValue(fmt="{avg:.4f}")
+         space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+         if torch.cuda.is_available():
+             log_msg = self.delimiter.join(
+                 [
+                     header,
+                     "[{0" + space_fmt + "}/{1}]",
+                     "eta: {eta}",
+                     "{meters}",
+                     "time: {time}",
+                     "data: {data}",
+                     "max mem: {memory:.0f}",
+                 ]
+             )
+         else:
+             log_msg = self.delimiter.join(
+                 [
+                     header,
+                     "[{0" + space_fmt + "}/{1}]",
+                     "eta: {eta}",
+                     "{meters}",
+                     "time: {time}",
+                     "data: {data}",
+                 ]
+             )
+         MB = 1024.0 * 1024.0
+         for obj in iterable:
+             data_time.update(time.time() - end)
+             yield obj
+             iter_time.update(time.time() - end)
+             if i % print_freq == 0:
+                 eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                 if torch.cuda.is_available():
+                     print(
+                         log_msg.format(
+                             i,
+                             len(iterable),
+                             eta=eta_string,
+                             meters=str(self),
+                             time=str(iter_time),
+                             data=str(data_time),
+                             memory=torch.cuda.max_memory_allocated() / MB,
+                         )
+                     )
+                 else:
+                     print(
+                         log_msg.format(
+                             i,
+                             len(iterable),
+                             eta=eta_string,
+                             meters=str(self),
+                             time=str(iter_time),
+                             data=str(data_time),
+                         )
+                     )
+             i += 1
+             end = time.time()
+         total_time = time.time() - start_time
+         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+         print(f"{header} Total time: {total_time_str}")
+
+
+ def is_dist_avail_and_initialized():
+     if not dist.is_available():
+         return False
+     if not dist.is_initialized():
+         return False
+     return True
+
+
+ def reduce_across_processes(val):
+     if not is_dist_avail_and_initialized():
+         # nothing to sync, but we still convert to tensor for consistency with the distributed case.
+         return torch.tensor(val)
+
+     t = torch.tensor(val, device="cuda")
+     dist.barrier()
+     dist.all_reduce(t)
+     return t
+
+
+ def accuracy(output, target, topk=(1,)):
+     """Computes the accuracy over the k top predictions for the specified
+     values of k."""
+     with torch.inference_mode():
+         maxk = max(topk)
+         batch_size = target.size(0)
+         if target.ndim == 2:
+             target = target.max(dim=1)[1]
+
+         _, pred = output.topk(maxk, 1, True, True)
+         pred = pred.t()
+         correct = pred.eq(target[None])
+
+         res = []
+         for k in topk:
+             correct_k = correct[:k].flatten().sum(dtype=torch.float32)
+             res.append(correct_k * (100.0 / batch_size))
+         return res
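
A tiny worked example of `utils.accuracy`, which underlies the Acc@1/Acc@5 numbers printed by `test.py` (the logits and labels here are made up for illustration):

```python
# Worked example: top-k accuracy on a fake 2-sample, 3-class batch.
import torch
import utils

output = torch.tensor([[0.1, 0.7, 0.2],     # predicts class 1 (correct)
                       [0.8, 0.15, 0.05]])  # predicts class 0 (wrong; class 1 is second)
target = torch.tensor([1, 1])
top1, top2 = utils.accuracy(output, target, topk=(1, 2))
print(top1.item(), top2.item())  # 50.0 100.0
```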