"""Report the multiply-accumulate count (in millions) of a PyTorch or ONNX model.

Usage::

    python <script> --model-ckpt path/to/model.onnx
    python <script> --model-ckpt path/to/checkpoint.pth --input-size 224
"""

import argparse

import onnx
import torch
from fvcore.nn import FlopCountAnalysis
from onnx2torch import convert


def get_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, matching plain ``parse_args()``.

    Returns:
        argparse.Namespace with ``model_ckpt`` (str) and ``input_size`` (int).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-ckpt",
        type=str,
        # main() dereferences this path on every code path, so fail early with
        # a usage message instead of an AttributeError on None.
        required=True,
        help="Model checkpoint. Both PyTorch and ONNX models can be used.",
    )
    parser.add_argument(
        "--input-size",
        type=int,
        default=224,
        help="Height and width of the dummy input image (default: 224).",
    )
    return parser.parse_args(argv)


def _strip_module_prefix(state_dict, prefix="module."):
    """Return a new dict of the entries whose key starts with ``prefix``, prefix removed.

    Entries without the prefix are dropped. Only the leading prefix is
    removed, so nested submodules that happen to be named ``module`` keep
    their names.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def main():
    """Load the checkpoint named on the command line and print its MMACs."""
    args = get_args()
    ckpt_path = args.model_ckpt
    # Match extensions case-insensitively so e.g. "MODEL.PTH" is accepted too.
    lowered = ckpt_path.lower()

    if lowered.endswith(".onnx"):
        # Convert the ONNX graph into a torch.nn.Module so fvcore can trace it.
        model = convert(onnx.load(ckpt_path))
    elif lowered.endswith((".pth", ".pt")):
        # NOTE(review): torch.load unpickles arbitrary objects; only use trusted
        # files. Newer PyTorch releases default to weights_only=True, which may
        # refuse to unpickle the full module stored under "model_ckpt" —
        # confirm against the installed torch version.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        # The checkpoint holds a complete module object, not only its weights.
        model = checkpoint["model_ckpt"]
        if "model_ema" in checkpoint:
            # When EMA weights are present they replace the stored weights.
            # Their keys carry a "module." wrapper prefix; entries without it
            # (presumably bookkeeping such as an averaging counter — verify)
            # are skipped.
            model.load_state_dict(_strip_module_prefix(checkpoint["model_ema"]))
    else:
        raise RuntimeError(
            f"Cannot process file {args.model_ckpt} with extension {args.model_ckpt.split('.')[-1]}"
        )

    model.eval()
    # Dummy batch of one image; assumes a 3-channel image model.
    dummy_input = torch.ones((1, 3, args.input_size, args.input_size))
    # fvcore counts one fused multiply-add as a single "flop", so total() is
    # reported as MACs here.
    flops = FlopCountAnalysis(model.cpu(), dummy_input).total()
    print(f"MMACs = {flops/1e6}")


if __name__ == "__main__":
    main()