Spaces:
Sleeping
Sleeping
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import argparse | |
| import torch | |
| from fvcore.nn import FlopCountAnalysis, flop_count_table | |
| from mmengine import Config | |
| from mmengine.registry import init_default_scope | |
| from mmocr.registry import MODELS | |
def parse_args(argv=None):
    """Parse the command-line arguments of the FLOPs counting script.

    Args:
        argv (list[str], optional): Argument strings to parse. When left as
            ``None`` argparse falls back to ``sys.argv[1:]``, which is what
            the script entry point relies on. Defaults to None.

    Returns:
        argparse.Namespace: Namespace holding ``config`` (str, path of the
        config file) and ``shape`` (list of one or more ints, ``[640, 640]``
        when ``--shape`` is not given).
    """
    # The description is shown by ``--help``; it must describe this tool, not
    # the training script it was copied from.
    parser = argparse.ArgumentParser(
        description='Get the FLOPs of a model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[640, 640],
        help='input image size')
    # ``nargs='+'`` accepts any number of ints; the caller decides how many
    # values make a valid shape.
    return parser.parse_args(argv)
def _resolve_input_shape(shape):
    """Expand the ``--shape`` values into a ``(1, 3, h, w)`` input shape.

    Args:
        shape (list[int]): One value (used for both height and width) or two
            values (height, width).

    Returns:
        tuple[int, int, int, int]: ``(1, 3, h, w)``.

    Raises:
        ValueError: If ``shape`` does not hold one or two values, or if the
            height or width is not positive.
    """
    if len(shape) == 1:
        h = w = shape[0]
    elif len(shape) == 2:
        h, w = shape
    else:
        raise ValueError('invalid input shape, please use --shape h w')
    # Reject sizes that cannot describe an image before any model is built.
    if h <= 0 or w <= 0:
        raise ValueError(
            f'input height and width must be positive, got h={h}, w={w}')
    return (1, 3, h, w)


def main():
    """Build the model described by the config and print its FLOPs table.

    Reads the config path and input size from the command line, builds the
    model through the ``MODELS`` registry, runs ``FlopCountAnalysis`` on an
    all-ones tensor of shape ``(1, 3, h, w)`` and prints the resulting table
    followed by a caution message.

    Raises:
        ValueError: If ``--shape`` does not hold one or two positive values.
    """
    args = parse_args()
    input_shape = _resolve_input_shape(args.shape)

    cfg = Config.fromfile(args.config)
    init_default_scope(cfg.get('default_scope', 'mmocr'))
    model = MODELS.build(cfg.model)
    # Count the inference graph: in training mode layers such as dropout and
    # batch norm behave differently and batch-norm statistics would be
    # updated by the dummy input. Assumes the built model is a
    # ``torch.nn.Module`` (it is handed to ``FlopCountAnalysis`` below).
    model.eval()

    flops = FlopCountAnalysis(model, torch.ones(input_shape))
    print(flop_count_table(flops))
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')
# Run the FLOPs analysis only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()