rawalkhirodkar's picture
Add initial commit
28c256d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from argparse import ArgumentParser
import mmcv
import requests
import torch
from mmengine.structures import InstanceData
from mmdet.apis import inference_detector, init_detector
from mmdet.registry import VISUALIZERS
from mmdet.structures import DetDataSample
def parse_args():
parser = ArgumentParser()
parser.add_argument('img', help='Image file')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument('model_name', help='The model name in the server')
parser.add_argument(
'--inference-addr',
default='127.0.0.1:8080',
help='Address and port of the inference server')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--score-thr', type=float, default=0.5, help='bbox score threshold')
parser.add_argument(
'--work-dir',
type=str,
default=None,
help='output directory to save drawn results.')
args = parser.parse_args()
return args
def align_ts_output(inputs, metainfo, device):
bboxes = []
labels = []
scores = []
for i, pred in enumerate(inputs):
bboxes.append(pred['bbox'])
labels.append(pred['class_label'])
scores.append(pred['score'])
pred_instances = InstanceData(metainfo=metainfo)
pred_instances.bboxes = torch.tensor(
bboxes, dtype=torch.float32, device=device)
pred_instances.labels = torch.tensor(
labels, dtype=torch.int64, device=device)
pred_instances.scores = torch.tensor(
scores, dtype=torch.float32, device=device)
ts_data_sample = DetDataSample(pred_instances=pred_instances)
return ts_data_sample
def main(args):
# build the model from a config file and a checkpoint file
model = init_detector(args.config, args.checkpoint, device=args.device)
# test a single image
pytorch_results = inference_detector(model, args.img)
keep = pytorch_results.pred_instances.scores >= args.score_thr
pytorch_results.pred_instances = pytorch_results.pred_instances[keep]
# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
# the dataset_meta is loaded from the checkpoint and
# then pass to the model in init_detector
visualizer.dataset_meta = model.dataset_meta
# show the results
img = mmcv.imread(args.img)
img = mmcv.imconvert(img, 'bgr', 'rgb')
pt_out_file = None
ts_out_file = None
if args.work_dir is not None:
os.makedirs(args.work_dir, exist_ok=True)
pt_out_file = os.path.join(args.work_dir, 'pytorch_result.png')
ts_out_file = os.path.join(args.work_dir, 'torchserve_result.png')
visualizer.add_datasample(
'pytorch_result',
img.copy(),
data_sample=pytorch_results,
draw_gt=False,
out_file=pt_out_file,
show=True,
wait_time=0)
url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
with open(args.img, 'rb') as image:
response = requests.post(url, image)
metainfo = pytorch_results.pred_instances.metainfo
ts_results = align_ts_output(response.json(), metainfo, args.device)
visualizer.add_datasample(
'torchserve_result',
img,
data_sample=ts_results,
draw_gt=False,
out_file=ts_out_file,
show=True,
wait_time=0)
assert torch.allclose(pytorch_results.pred_instances.bboxes,
ts_results.pred_instances.bboxes)
assert torch.allclose(pytorch_results.pred_instances.labels,
ts_results.pred_instances.labels)
assert torch.allclose(pytorch_results.pred_instances.scores,
ts_results.pred_instances.scores)
if __name__ == '__main__':
args = parse_args()
main(args)