Spaces:
Running
Running
# Copyright (c) OpenMMLab. All rights reserved. | |
import argparse | |
import torch | |
from fvcore.nn import FlopCountAnalysis, flop_count_table | |
from mmengine import Config | |
from mmengine.registry import init_default_scope | |
from mmocr.registry import MODELS | |
def parse_args(args=None):
    """Parse command-line arguments for the FLOP counting tool.

    Args:
        args (list[str], optional): Argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, matching the
            previous zero-argument behavior.

    Returns:
        argparse.Namespace: ``config`` (str, path to the model config) and
        ``shape`` (list[int], one value for a square input or two values
        for ``h w``).
    """
    parser = argparse.ArgumentParser(
        description='Count the FLOPs of a model built from a config file')
    parser.add_argument('config', help='model config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[640, 640],
        help='input image size: "h w", or a single value for a square input')
    return parser.parse_args(args)
def main():
    """Build the model described by the config and print its FLOP table.

    The model is fed a single all-ones image of shape ``(1, 3, h, w)``,
    where ``h``/``w`` come from ``--shape``.

    Raises:
        ValueError: If ``--shape`` does not hold one or two values, or if
            any of them is not a positive integer.
    """
    args = parse_args()
    if len(args.shape) == 1:
        h = w = args.shape[0]
    elif len(args.shape) == 2:
        h, w = args.shape
    else:
        raise ValueError('invalid input shape, please use --shape h w')
    if h <= 0 or w <= 0:
        raise ValueError(
            f'input shape must be positive integers, got {args.shape}')
    # One dummy image in (batch, channels, height, width) layout.
    input_shape = (1, 3, h, w)

    cfg = Config.fromfile(args.config)
    init_default_scope(cfg.get('default_scope', 'mmocr'))
    model = MODELS.build(cfg.model)
    # Measure the inference graph: without eval(), modules that branch on
    # `self.training` (dropout, BatchNorm, train-only decoder paths) would be
    # traced in their training configuration.
    model.eval()

    # fvcore performs the trace lazily when statistics are first requested,
    # so building the table must happen inside the same no_grad context.
    with torch.no_grad():
        flops = FlopCountAnalysis(model, torch.ones(input_shape))
        flops_data = flop_count_table(flops)
    print(flops_data)

    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')
# Script entry point: only run when executed directly, not when imported.
if __name__ == '__main__':
    main()