import argparse

import mmcv
import numpy as np
import torch
import torch._C
import torch.serialization
from mmcv.runner import load_checkpoint
from torch import nn

from mmseg.models import build_segmentor

# Seed torch's global RNG at import time so any torch-random state consumed
# below (presumably weight initialisation when no checkpoint is loaded) is
# reproducible across runs.
torch.manual_seed(3)


def digit_version(version_str):
    """Convert a version string into a list of integers for comparison.

    A PEP 440 local version label (everything after ``'+'``, e.g. the
    ``+cu111`` in ``'1.8.0+cu111'``) is discarded first. Without that step
    the component ``'0+cu111'`` would be skipped, producing ``[1, 8]``,
    which compares *less than* ``[1, 8, 0]`` and would wrongly reject
    CUDA/CPU builds of the minimum supported torch version.

    Each remaining dot-separated component that is a plain integer is kept.
    A release-candidate component such as ``'0rc1'`` is encoded as
    ``[0 - 1, 1]`` so that ``1.8.0rc1`` orders before ``1.8.0``. Any other
    non-numeric component is skipped.

    Args:
        version_str (str): Version string such as ``'1.8.1+cu111'``.

    Returns:
        list[int]: Version components suitable for list comparison.
    """
    release = version_str.split('+', 1)[0]
    parts = []
    for x in release.split('.'):
        if x.isdigit():
            parts.append(int(x))
        elif 'rc' in x:
            patch_version = x.split('rc')
            parts.append(int(patch_version[0]) - 1)
            parts.append(int(patch_version[1]))
    return parts


def check_torch_version():
    """Assert that the installed torch is new enough for TorchScript export.

    Raises:
        AssertionError: If ``torch.__version__`` is older than 1.8.0.
    """
    torch_minimum_version = '1.8.0'
    torch_version = digit_version(torch.__version__)
    assert (torch_version >= digit_version(torch_minimum_version)), \
        f'Torch=={torch.__version__} is not support for converting to ' \
        f'torchscript. Please install pytorch>={torch_minimum_version}.'


def _convert_batchnorm(module):
    """Recursively replace ``SyncBatchNorm`` layers with ``BatchNorm2d``.

    Affine parameters and running statistics are carried over to the new
    layer. Children of every module are converted and re-attached under the
    same names.

    Args:
        module (nn.Module): Root module to convert.

    Returns:
        nn.Module: ``module`` itself (with converted children) or, if it was
        a ``SyncBatchNorm``, the replacement ``BatchNorm2d``.
    """
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        # NOTE(review): always converts to the 2d variant; assumes the
        # SyncBN layers in these segmentors operate on 4D inputs.
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    return module_output


def _demo_mm_inputs(input_shape, num_classes):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): input batch dimensions ``(N, C, H, W)``.
        num_classes (int): number of semantic classes.

    Returns:
        dict: ``'imgs'`` (float tensor of ``input_shape`` with
        ``requires_grad=True``), ``'img_metas'`` (list of ``N`` meta dicts)
        and ``'gt_semantic_seg'`` (long tensor of shape ``(N, 1, H, W)``).
    """
    (N, C, H, W) = input_shape

    rng = np.random.RandomState(0)

    imgs = rng.rand(*input_shape)
    # ``high`` is exclusive in ``randint``: ``high=num_classes`` draws labels
    # from the full range [0, num_classes - 1] and, unlike
    # ``num_classes - 1``, does not raise for single-class models.
    segs = rng.randint(
        low=0, high=num_classes, size=(N, 1, H, W)).astype(np.uint8)

    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '.png',
        'scale_factor': 1.0,
        'flip': False,
    } for _ in range(N)]

    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs


def pytorch2libtorch(model,
                     input_shape,
                     show=False,
                     output_file='tmp.pt',
                     verify=False):
    """Export Pytorch model to TorchScript model and verify the outputs are
    same between Pytorch and TorchScript.

    Args:
        model (nn.Module): Pytorch model we want to export.
        input_shape (tuple): Use this input shape to construct
            the corresponding dummy input and execute the model.
        show (bool): Whether to print the computation graph. Default: False.
        output_file (string): The path to where we store the
            output TorchScript model. Default: `tmp.pt`.
        verify (bool): Whether to compare the outputs between
            Pytorch and TorchScript (passed to ``torch.jit.trace`` as
            ``check_trace``). Default: False.
    """
    if isinstance(model.decode_head, nn.ModuleList):
        num_classes = model.decode_head[-1].num_classes
    else:
        num_classes = model.decode_head.num_classes

    mm_inputs = _demo_mm_inputs(input_shape, num_classes)
    # Only the image tensor is needed for tracing; metas/labels are unused.
    imgs = mm_inputs.pop('imgs')

    # replace the original forward with forward_dummy so tracing sees a
    # tensor-in / tensor-out signature
    model.forward = model.forward_dummy
    model.eval()
    traced_model = torch.jit.trace(
        model,
        example_inputs=imgs,
        check_trace=verify,
    )

    if show:
        print(traced_model.graph)

    traced_model.save(output_file)
    print(f'Successfully exported TorchScript model: {output_file}')


def parse_args():
    """Parse the command-line options of this converter script."""
    parser = argparse.ArgumentParser(
        description='Convert MMSeg to TorchScript')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
    parser.add_argument(
        '--show', action='store_true', help='show TorchScript graph')
    parser.add_argument(
        '--verify', action='store_true', help='verify the TorchScript model')
    parser.add_argument('--output-file', type=str, default='tmp.pt')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[512, 512],
        help='input image size (height, width)')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    check_torch_version()

    # A single value means a square input; two values are (height, width).
    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3, *args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = mmcv.Config.fromfile(args.config)
    cfg.model.pretrained = None

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    segmentor = build_segmentor(
        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    # convert SyncBN to BN
    segmentor = _convert_batchnorm(segmentor)

    if args.checkpoint:
        load_checkpoint(segmentor, args.checkpoint, map_location='cpu')

    # convert the PyTorch model to LibTorch model
    pytorch2libtorch(
        segmentor,
        input_shape,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify)