|
import argparse |
|
|
|
import onnx |
|
import torch |
|
from fvcore.nn import FlopCountAnalysis |
|
from onnx2torch import convert |
|
|
|
|
|
def get_args():
    """Parse the command-line arguments of this script.

    Returns:
        argparse.Namespace: namespace whose ``model_ckpt`` attribute holds
        the path passed via ``--model-ckpt``.

    Raises:
        SystemExit: raised by argparse (after printing a usage message) when
            ``--model-ckpt`` is missing or the arguments are malformed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-ckpt",
        type=str,
        # main() calls ``.endswith`` on this value right away, so a missing
        # option used to surface as an AttributeError on None. Let argparse
        # reject the invocation with a proper usage error instead.
        required=True,
        help="Model checkpoint. Both PyTorch and ONNX models can be used.",
    )

    return parser.parse_args()
|
|
|
|
|
def _load_model(model_ckpt):
    """Load the model stored at ``model_ckpt``.

    Args:
        model_ckpt: path to an ``.onnx`` file or a ``.pth``/``.pt`` checkpoint.

    Returns:
        The model object: the result of ``onnx2torch.convert`` for ONNX
        files, or the ``"model_ckpt"`` entry of the PyTorch checkpoint.

    Raises:
        RuntimeError: if the file extension is none of the supported ones.
    """
    if model_ckpt.endswith(".onnx"):
        return convert(onnx.load(model_ckpt))

    if model_ckpt.endswith((".pth", ".pt")):
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only run this script on checkpoints you trust.
        checkpoint = torch.load(model_ckpt, map_location="cpu")
        # The entry is used as a model object below (load_state_dict here,
        # eval/cpu in main), so it is presumably a pickled module -- confirm
        # against the code that writes these checkpoints.
        model = checkpoint["model_ckpt"]
        if "model_ema" in checkpoint:
            # Keep only the entries whose key contains "module." and drop
            # that marker so the keys line up with the model's own names.
            state_dict = {
                key.replace("module.", ""): value
                for key, value in checkpoint["model_ema"].items()
                if "module." in key
            }
            model.load_state_dict(state_dict)
        return model

    raise RuntimeError(
        f"Cannot process file {model_ckpt} with extension {model_ckpt.split('.')[-1]}"
    )


def main():
    """Load the model named on the command line and print its MMAC count.

    The count is measured with fvcore's ``FlopCountAnalysis`` on a single
    all-ones input of shape ``(1, 3, 224, 224)`` and printed in millions.

    Raises:
        RuntimeError: if no checkpoint path was supplied or its extension is
            not supported.
    """
    args = get_args()

    # Fail with a clear message instead of an AttributeError on None when
    # the option was not given.
    if args.model_ckpt is None:
        raise RuntimeError("No model checkpoint given; pass --model-ckpt <path>.")

    model = _load_model(args.model_ckpt)
    model.eval()

    # Count on the CPU with a fixed dummy input.
    analysis = FlopCountAnalysis(model.cpu(), torch.ones((1, 3, 224, 224)))
    total = analysis.total()

    print(f"MMACs = {total/1e6}")
|
|
|
|
|
# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
|