|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
|
import onnxruntime
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from mmdet.structures.bbox import bbox2roi
|
|
|
from mmengine import Config
|
|
|
from mmengine.registry import init_default_scope
|
|
|
from mmengine.runner import load_checkpoint
|
|
|
|
|
|
from mmaction.registry import MODELS
|
|
|
|
|
|
|
|
|
def parse_args():
    """Parse the command-line arguments of the ONNX export script.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``config``,
        ``checkpoint``, ``num_frames``, ``shape``, ``device`` and
        ``output_file``.
    """
    parser = argparse.ArgumentParser(
        description='Convert a spatio-temporal detection model in MMAction2 '
        'to an ONNX file')
    parser.add_argument('config', help='config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--num_frames', type=int, default=8, help='number of input frames.')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[256, 455],
        help='input image size: either "H W" or a single value used for '
        'both height and width')
    parser.add_argument(
        '--device', type=str, default='cpu', help='CPU/CUDA device option')
    parser.add_argument(
        '--output_file',
        type=str,
        default='stdet.onnx',
        help='file name of the output onnx file')
    args = parser.parse_args()
    return args
|
|
|
|
|
|
|
|
|
class SpatialMaxPool3d(nn.Module):
    """Global max pooling over the last two axes, keeping them as size 1.

    Built only from ``Tensor.max`` reductions so that it exports cleanly to
    ONNX. It is used in place of the bbox head's original spatial pooling
    layer.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Tensor.max reduces a single dim at a time: width first, then height.
        col_max, _ = x.max(dim=-1, keepdim=True)
        grid_max, _ = col_max.max(dim=-2, keepdim=True)
        return grid_max
|
|
|
|
|
|
|
|
|
class SpatialAvgPool(nn.Module):
    """Global average pooling over the last two axes, keeping them as size 1.

    Implemented with a plain ``Tensor.mean`` so it exports cleanly to ONNX.
    It is used in place of the bbox head's original adaptive average spatial
    pooling.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # ``keepdim`` is the documented PyTorch keyword. ``keepdims`` is only
        # a numpy-compat alias and is not accepted by every PyTorch version.
        return x.mean(dim=(-1, -2), keepdim=True)
|
|
|
|
|
|
|
|
|
class TemporalMaxPool3d(nn.Module):
    """Max pooling over the third-from-last axis, keeping it as size 1.

    It is used in place of the bbox head's original temporal pooling layer
    when that layer is not an ``nn.AdaptiveAvgPool3d``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        peak, _ = x.max(dim=-3, keepdim=True)
        return peak
|
|
|
|
|
|
|
|
|
class TemporalAvgPool3d(nn.Module):
    """Average pooling over the third-from-last axis, keeping it as size 1.

    It is used in place of the bbox head's original temporal pooling when
    that layer is an ``nn.AdaptiveAvgPool3d``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        averaged = x.mean(dim=-3, keepdim=True)
        return averaged
|
|
|
|
|
|
|
|
|
class GlobalPool2d(nn.Module):
    """Export-friendly global 2D pooling broadcast to a fixed output grid.

    Processing steps:

    1. Average-pool the input with an ``nn.AvgPool2d(pool_size)`` kernel.
    2. Collapse the remaining spatial grid to 1x1, using a max
       (``later_max=True``) or a mean (``later_max=False``).
    3. Expand the result to an ``output_size`` grid.

    Args:
        pool_size (int): Kernel size (and stride) of the ``nn.AvgPool2d``.
        output_size (int | tuple[int, int]): Side of the square output grid,
            or an explicit ``(height, width)`` pair.
        later_max (bool): Reduce the pooled grid with max if True, otherwise
            with mean. Defaults to True.

    The input must be a 4-D tensor: ``forward`` expands to four sizes.
    """

    def __init__(self, pool_size, output_size, later_max=True):
        super().__init__()
        self.pool = nn.AvgPool2d(pool_size)
        self.max = later_max
        self.output_size = output_size
        if isinstance(output_size, int):
            self._out_hw = (output_size, output_size)
        else:
            self._out_hw = tuple(output_size)

    def forward(self, x):
        x = self.pool(x)
        if self.max:
            # Tensor.max reduces one dim at a time: width, then height.
            x = x.max(dim=-1, keepdim=True)[0]
            x = x.max(dim=-2, keepdim=True)[0]
        else:
            x = x.mean(dim=(-1, -2), keepdim=True)
        return x.expand(-1, -1, *self._out_hw)
|
|
|
|
|
|
|
|
|
class STDet(nn.Module):
    """Wrap an MMAction2 spatio-temporal detector for ONNX export.

    Keeps the backbone, the RoI feature extractor and the bbox head of
    ``base_model``. Their pooling layers are swapped for the export-friendly
    modules defined in this file.

    Args:
        base_model (nn.Module): Built detector exposing ``backbone`` and
            ``roi_head.bbox_roi_extractor`` / ``roi_head.bbox_head``.
        input_tensor (torch.Tensor): Dummy input. Only its last two dims (the
            spatial size) are read, to size the global pooling kernel.

    Raises:
        ValueError: If the input is too small for the derived global pooling
            kernel size to be at least 1.
    """

    def __init__(self, base_model, input_tensor):
        super().__init__()
        self.backbone = base_model.backbone
        self.bbox_roi_extractor = base_model.roi_head.bbox_roi_extractor
        self.bbox_head = base_model.roi_head.bbox_head

        avg_spatial = isinstance(self.bbox_head.spatial_pool,
                                 nn.AdaptiveAvgPool3d)

        if isinstance(self.bbox_head.temporal_pool, nn.AdaptiveAvgPool3d):
            self.bbox_head.temporal_pool = TemporalAvgPool3d()
        else:
            self.bbox_head.temporal_pool = TemporalMaxPool3d()
        if avg_spatial:
            self.bbox_head.spatial_pool = SpatialAvgPool()
        else:
            self.bbox_head.spatial_pool = SpatialMaxPool3d()

        # Not every RoI extractor has a global pooling branch (presumably only
        # when global context is enabled). Reading ``global_pool``
        # unconditionally would raise AttributeError for the others.
        global_pool = getattr(self.bbox_roi_extractor, 'global_pool', None)
        if global_pool is not None:
            output_size = global_pool.output_size
            # Feature-map stride of the backbone. 16 is the previous
            # hard-coded value and is used when the extractor lacks the
            # attribute.
            featmap_stride = getattr(self.bbox_roi_extractor,
                                     'featmap_stride', 16)
            pool_size = (
                min(input_tensor.shape[-2:]) // featmap_stride // output_size)
            if pool_size < 1:
                raise ValueError(
                    f'Input spatial size {tuple(input_tensor.shape[-2:])} is '
                    f'too small for a global pool with output_size='
                    f'{output_size} and feature stride {featmap_stride}.')
            self.bbox_roi_extractor.global_pool = GlobalPool2d(
                pool_size, output_size, later_max=not avg_spatial)

    def forward(self, input_tensor, rois):
        """Return the bbox head's classification scores for ``rois``."""
        feat = self.backbone(input_tensor)
        bbox_feats, _ = self.bbox_roi_extractor(feat, rois)
        cls_score = self.bbox_head(bbox_feats)
        return cls_score
|
|
|
|
|
|
|
|
|
def main():
    """Export a spatio-temporal detector to ONNX and check it with ONNX Runtime.

    Processing steps:

    1. Build the model from ``args.config`` and load ``args.checkpoint``.
    2. Wrap it in ``STDet`` and run it on a random clip with three fixed
       proposals.
    3. Export it to ``args.output_file``.
    4. Re-run the exported file with ONNX Runtime (CPU) and report the
       maximum absolute difference from the PyTorch output.

    Raises:
        ValueError: If ``--shape`` does not have exactly one or two values.
    """
    args = parse_args()
    config = Config.fromfile(args.config)

    if config.model.type != 'FastRCNN':
        print('This script serves the sole purpose of converting spatial '
              'temporal detection models in MMAction2 to ONNX files. Please '
              'note that attempting to convert other models using this script '
              'may not yield successful results.\n\n')

    init_default_scope(config.get('default_scope', 'mmaction'))

    base_model = MODELS.build(config.model)
    load_checkpoint(base_model, args.checkpoint, map_location='cpu')
    base_model.to(args.device)

    if len(args.shape) == 1:
        input_shape = (args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    # Random clip of shape (1, 3, num_frames, H, W), clamped to [-3, 3].
    input_tensor = torch.randn(1, 3, args.num_frames, *input_shape)
    input_tensor = input_tensor.clamp(-3, 3).to(args.device)
    # Fixed dummy proposals used only to trace/export and compare outputs.
    proposal = torch.Tensor([[22., 59., 67., 157.], [186., 73., 217., 159.],
                             [407., 95., 431., 168.]])

    rois = bbox2roi([proposal]).to(args.device)

    model = STDet(base_model, input_tensor).to(args.device)
    model.eval()
    # The reference pass needs no autograd graph.
    with torch.no_grad():
        cls_score = model(input_tensor, rois)
    print(f'Model output shape: {cls_score.shape}')

    torch.onnx.export(
        model, (input_tensor, rois),
        args.output_file,
        input_names=['input_tensor', 'rois'],
        output_names=['cls_score'],
        export_params=True,
        do_constant_folding=True,
        verbose=False,
        opset_version=11,
        dynamic_axes={
            'input_tensor': {
                0: 'batch_size',
                3: 'height',
                4: 'width'
            },
            'rois': {
                0: 'total_num_bbox_for_the_batch'
            },
            'cls_score': {
                0: 'total_num_bbox_for_the_batch'
            }
        })

    print(f'Successfully export the onnx file to {args.output_file}')

    # Verify the exported graph. Providers are passed explicitly because
    # ONNX Runtime >= 1.9 requires it on builds that enable GPU providers.
    # CPU keeps the comparison deterministic.
    session = onnxruntime.InferenceSession(
        args.output_file, providers=['CPUExecutionProvider'])
    input_feed = {
        'input_tensor': input_tensor.cpu().numpy(),
        'rois': rois.cpu().numpy()
    }
    outputs = session.run(['cls_score'], input_feed=input_feed)[0]
    diff = abs(cls_score.cpu().numpy() - outputs).max()
    if diff < 1e-5:
        print('The output difference is smaller than 1e-5.')
    else:
        # Previously a mismatch was silent; surface it so it isn't missed.
        print(f'Warning: the maximum output difference between PyTorch and '
              f'ONNX Runtime is {diff:.3e}, which exceeds 1e-5.')


if __name__ == '__main__':
    main()
|
|
|
|