savchenkoyana commited on
Commit
061bac4
•
1 Parent(s): 778986f

add ViT x4.8 ONNX, small fixes in test.py, and allow measuring macs on ONNX

Browse files
README.md CHANGED
@@ -28,7 +28,8 @@ Evaluation code is also based on Torchvision references.
28
  | Model | Latency (MMACs) | Accuracy (%) |
29
  |--------------------------|:---------------:|:-------------:|
30
  | **ViT-B/32 Torchvision** | 4413.99 | 75.91 |
31
- | **ViT-B/32 ENOT** | 492.23 (x8.97) | 73.72 (-2.19) |
 
32
 
33
  ## MobileNetV2
34
 
 
28
  | Model | Latency (MMACs) | Accuracy (%) |
29
  |--------------------------|:---------------:|:-------------:|
30
  | **ViT-B/32 Torchvision** | 4413.99 | 75.91 |
31
+ | **ViT-B/32 ENOT (x4.8)** | 911.80 (x4.84) | 75.68 (-0.23) |
32
+ | **ViT-B/32 ENOT (x9)** | 490.78 (x8.99) | 73.72 (-2.19) |
33
 
34
  ## MobileNetV2
35
 
ViT-B-32/{ViT-B-32-ENOT.pth → ViT-B-32-ENOT-x4_8.onnx} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:92a81cef913af4012215215400049317168939624f09d76e9043aee2342af356
3
- size 157444613
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3597a973923ab2be41e046ef08cbccadd67279853e78185194f906086063e626
3
+ size 72211694
ViT-B-32/{ViT-B-32-ENOT.onnx → ViT-B-32-ENOT-x9.onnx} RENAMED
File without changes
measure_mac.py CHANGED
@@ -1,12 +1,18 @@
1
  import argparse
2
 
 
3
  import torch
4
  from fvcore.nn import FlopCountAnalysis
 
5
 
6
 
7
  def get_args():
8
  parser = argparse.ArgumentParser()
9
- parser.add_argument("--model-ckpt", type=str)
 
 
 
 
10
 
11
  return parser.parse_args()
12
 
@@ -14,8 +20,24 @@ def get_args():
14
  def main():
15
  args = get_args()
16
 
17
- checkpoint = torch.load(args.model_ckpt, map_location="cpu")
18
- model = checkpoint["model_ckpt"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  model.eval()
20
 
21
  flops = FlopCountAnalysis(model.cpu(), torch.ones((1, 3, 224, 224)))
 
1
  import argparse
2
 
3
+ import onnx
4
  import torch
5
  from fvcore.nn import FlopCountAnalysis
6
+ from onnx2torch import convert
7
 
8
 
9
  def get_args():
10
  parser = argparse.ArgumentParser()
11
+ parser.add_argument(
12
+ "--model-ckpt",
13
+ type=str,
14
+ help="Model checkpoint. Both PyTorch and ONNX models can be used.",
15
+ )
16
 
17
  return parser.parse_args()
18
 
 
20
  def main():
21
  args = get_args()
22
 
23
+ if args.model_ckpt.endswith(".onnx"):
24
+ onnx_model = onnx.load(args.model_ckpt)
25
+ model = convert(onnx_model)
26
+ elif args.model_ckpt.endswith((".pth", ".pt")):
27
+ checkpoint = torch.load(args.model_ckpt, map_location="cpu")
28
+ model = checkpoint["model_ckpt"]
29
+ if "model_ema" in checkpoint:
30
+ state_dict = {}
31
+ for key, value in checkpoint["model_ema"].items():
32
+ if not "module." in key:
33
+ continue
34
+ state_dict[key.replace("module.", "")] = value
35
+ model.load_state_dict(state_dict)
36
+ else:
37
+ raise RuntimeError(
38
+ f"Cannot process file {args.model_ckpt} with extension {args.model_ckpt.split('.')[-1]}"
39
+ )
40
+
41
  model.eval()
42
 
43
  flops = FlopCountAnalysis(model.cpu(), torch.ones((1, 3, 224, 224)))
requirements.txt CHANGED
@@ -3,3 +3,4 @@ torchvision==0.14.1
3
  fvcore==0.1.5.post20221221
4
  onnxruntime-gpu==1.15.1
5
  onnx==1.13.1
 
 
3
  fvcore==0.1.5.post20221221
4
  onnxruntime-gpu==1.15.1
5
  onnx==1.13.1
6
+ onnx2torch==1.5.6
test.py CHANGED
@@ -96,6 +96,9 @@ def load_data(valdir):
96
  def main(args):
97
  print(args)
98
 
 
 
 
99
  if torch.cuda.is_available():
100
  device = torch.device("cuda")
101
  else:
@@ -128,6 +131,7 @@ def main(args):
128
  state_dict[key.replace("module.", "")] = value
129
  model.load_state_dict(state_dict)
130
  model = model.to(device)
 
131
 
132
  accuracy = evaluate(
133
  model=model,
 
96
  def main(args):
97
  print(args)
98
 
99
+ torch.backends.cudnn.benchmark = False
100
+ torch.backends.cudnn.deterministic = True
101
+
102
  if torch.cuda.is_available():
103
  device = torch.device("cuda")
104
  else:
 
131
  state_dict[key.replace("module.", "")] = value
132
  model.load_state_dict(state_dict)
133
  model = model.to(device)
134
+ model.eval()
135
 
136
  accuracy = evaluate(
137
  model=model,