File size: 1,326 Bytes
4d679c2
 
061bac4
4d679c2
 
061bac4
4d679c2
 
 
 
061bac4
 
 
 
 
4d679c2
 
 
 
 
 
 
061bac4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d679c2
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse
import os

import onnx
import torch
from fvcore.nn import FlopCountAnalysis
from onnx2torch import convert


def get_args():
    """Parse the command-line arguments of this script.

    Returns:
        argparse.Namespace with a ``model_ckpt`` attribute holding the path
        passed via ``--model-ckpt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-ckpt",
        type=str,
        # Required: main() calls string methods on this value, so a missing
        # flag would otherwise surface as an AttributeError on None instead
        # of a readable usage error.
        required=True,
        help="Model checkpoint. Both PyTorch and ONNX models can be used.",
    )

    return parser.parse_args()


def _ema_state_dict(ema_state):
    """Convert an EMA wrapper's state dict into one loadable by the bare model.

    Only entries whose key starts with ``"module."`` are kept, and that prefix
    is stripped exactly once. Any other entries are dropped (presumably
    bookkeeping stored by the EMA wrapper itself — confirm against the
    training code that wrote the checkpoint).
    """
    prefix = "module."
    return {
        key[len(prefix):]: value
        for key, value in ema_state.items()
        if key.startswith(prefix)
    }


def _load_model(path):
    """Load a model from an ONNX file or a PyTorch checkpoint.

    Args:
        path: Path to a ``.onnx`` file or a ``.pt``/``.pth`` checkpoint.
            The extension is matched case-insensitively.

    Returns:
        For ONNX input, the onnx2torch-converted module. For PyTorch input,
        the object stored under the ``"model_ckpt"`` key. If the checkpoint
        also has a ``"model_ema"`` entry, those EMA weights are loaded into it.

    Raises:
        RuntimeError: if the file extension is not supported.
    """
    lowered = path.lower()
    if lowered.endswith(".onnx"):
        return convert(onnx.load(path))
    if lowered.endswith((".pth", ".pt")):
        # NOTE(review): torch.load unpickles arbitrary objects, so only load
        # trusted checkpoints. Because "model_ckpt" holds a whole model object
        # rather than tensors, newer PyTorch releases whose default is
        # weights_only=True may refuse to load it; pass weights_only=False
        # there (TODO confirm the targeted torch version).
        checkpoint = torch.load(path, map_location="cpu")
        model = checkpoint["model_ckpt"]
        if "model_ema" in checkpoint:
            model.load_state_dict(_ema_state_dict(checkpoint["model_ema"]))
        return model
    # splitext avoids reporting the whole path as the "extension" when the
    # file name has no dot.
    extension = os.path.splitext(path)[1][1:]
    raise RuntimeError(
        f"Cannot process file {path} with extension {extension or '<none>'}"
    )


def _count_macs(model, input_shape=(1, 3, 224, 224)):
    """Return fvcore's total op count for one forward pass on a ones tensor.

    fvcore counts one fused multiply-add as one op, so the total is a MAC
    count. The model is moved to CPU because the dummy input is on CPU.
    """
    inputs = torch.ones(input_shape)
    return FlopCountAnalysis(model.cpu(), inputs).total()


def main():
    """Load the model given by ``--model-ckpt`` and print its MACs in millions."""
    args = get_args()

    model = _load_model(args.model_ckpt)
    model.eval()

    macs = _count_macs(model)

    print(f"MMACs = {macs/1e6}")


# Run main() only when executed as a script, not when imported.
if __name__ == "__main__":
    main()