# imagenet-benchmark / measure_mac.py
# Author: savchenkoyana
# Commit 061bac4: add ViT x4.8 ONNX, small fixes in test.py, and allow measuring MACs on ONNX
import argparse
import os

import onnx
import torch
from fvcore.nn import FlopCountAnalysis
from onnx2torch import convert
def get_args(argv=None):
    """Parse the command-line arguments of the MAC-measurement script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), ``argparse`` reads ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` whose ``model_ckpt`` attribute is the path
        to the model file (``.onnx``, ``.pt`` or ``.pth``).

    Raises:
        SystemExit: (from ``argparse``) when ``--model-ckpt`` is missing or
            the arguments are otherwise invalid.
    """
    parser = argparse.ArgumentParser(
        description="Measure the MACs of a PyTorch or ONNX model."
    )
    parser.add_argument(
        "--model-ckpt",
        type=str,
        # Required: every code path that consumes the result calls string
        # methods on this value, so a missing flag must fail up front.
        required=True,
        help="Model checkpoint. Both PyTorch and ONNX models can be used.",
    )
    return parser.parse_args(argv)
def _strip_module_prefix(ema_state):
    """Return EMA weights keyed for the bare model.

    Keeps only entries whose key starts with ``module.`` and removes that
    leading prefix. Entries without the prefix are dropped. Presumably these
    are EMA bookkeeping values rather than weights (TODO confirm against the
    training script that wrote the checkpoint).

    Only the *leading* prefix is removed. Inner occurrences of ``module.`` are
    part of the real parameter name and must be preserved.
    """
    prefix = "module."
    return {
        key[len(prefix):]: value
        for key, value in ema_state.items()
        if key.startswith(prefix)
    }


def _load_model(path):
    """Load a model from an ONNX file or a PyTorch checkpoint.

    Args:
        path: Path ending in ``.onnx`` (converted with ``onnx2torch``) or in
            ``.pt`` / ``.pth`` (a ``torch.save``-d checkpoint dict).

    Returns:
        The loaded model. For PyTorch checkpoints this is
        ``checkpoint["model_ckpt"]``, with the EMA weights loaded into it when
        the checkpoint contains a ``"model_ema"`` entry.

    Raises:
        RuntimeError: if the file extension is not supported.
    """
    if path.endswith(".onnx"):
        return convert(onnx.load(path))

    if path.endswith((".pth", ".pt")):
        # checkpoint["model_ckpt"] is used as a module object below
        # (load_state_dict / eval), so the checkpoint must be fully unpickled.
        # torch >= 2.6 defaults to weights_only=True, which rejects such
        # objects. The kwarg itself requires torch >= 1.13.
        # SECURITY: unpickling can execute arbitrary code, so only load
        # checkpoints you trust.
        checkpoint = torch.load(path, map_location="cpu", weights_only=False)
        model = checkpoint["model_ckpt"]
        if "model_ema" in checkpoint:
            # Prefer the EMA weights when the checkpoint provides them.
            model.load_state_dict(_strip_module_prefix(checkpoint["model_ema"]))
        return model

    # splitext handles extension-less names and dotted directories correctly,
    # unlike path.split(".")[-1].
    extension = os.path.splitext(path)[1].lstrip(".") or "<none>"
    raise RuntimeError(f"Cannot process file {path} with extension {extension}")


def main(input_shape=(1, 3, 224, 224)):
    """Load the model given by ``--model-ckpt`` and print its MAC count in millions.

    Args:
        input_shape: Shape of the all-ones dummy input traced through the
            model. The default matches the previously hard-coded
            ``(1, 3, 224, 224)``.
    """
    args = get_args()
    model = _load_model(args.model_ckpt)
    model.eval()

    dummy_input = torch.ones(input_shape)
    # Gradients are irrelevant for counting operations. no_grad avoids
    # building an autograd graph during tracing.
    with torch.no_grad():
        # fvcore counts one fused multiply-add as one "flop", so this total
        # is effectively a MAC count. The analysis is lazy: total() runs it.
        macs = FlopCountAnalysis(model.cpu(), dummy_input).total()
    print(f"MMACs = {macs / 1e6}")
if __name__ == "__main__":
    # Script entry point: measure only when run directly, never on import.
    main()