Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os | |
from argparse import ArgumentParser | |
import mmcv | |
import requests | |
import torch | |
from mmengine.structures import InstanceData | |
from mmdet.apis import inference_detector, init_detector | |
from mmdet.registry import VISUALIZERS | |
from mmdet.structures import DetDataSample | |
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv (list[str], optional): Argument strings to parse. Defaults to
            None, in which case argparse reads the process command line
            (``sys.argv[1:]``), exactly as before.

    Returns:
        argparse.Namespace: Namespace with the attributes ``img``,
        ``config``, ``checkpoint``, ``model_name``, ``inference_addr``,
        ``device``, ``score_thr`` and ``work_dir``.
    """
    parser = ArgumentParser()
    # Required positional arguments.
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('model_name', help='The model name in the server')
    # Optional arguments.
    parser.add_argument(
        '--inference-addr',
        default='127.0.0.1:8080',
        help='Address and port of the inference server')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.5, help='bbox score threshold')
    parser.add_argument(
        '--work-dir',
        type=str,
        default=None,
        help='output directory to save drawn results.')
    return parser.parse_args(argv)
def align_ts_output(inputs, metainfo, device):
    """Pack a list of prediction dicts into a ``DetDataSample``.

    Args:
        inputs: Iterable of predictions. Each item is indexed with the keys
            ``'bbox'``, ``'class_label'`` and ``'score'``.
        metainfo: Meta information forwarded unchanged to
            ``InstanceData(metainfo=...)``.
        device: Device on which the result tensors are created.

    Returns:
        DetDataSample: Sample whose ``pred_instances`` holds ``bboxes``
        (float32), ``labels`` (int64) and ``scores`` (float32) tensors
        built from ``inputs`` in their original order.
    """
    bboxes = []
    labels = []
    scores = []
    # One pass over the predictions, splitting them into per-field lists.
    for pred in inputs:
        bboxes.append(pred['bbox'])
        labels.append(pred['class_label'])
        scores.append(pred['score'])

    pred_instances = InstanceData(metainfo=metainfo)
    pred_instances.bboxes = torch.tensor(
        bboxes, dtype=torch.float32, device=device)
    pred_instances.labels = torch.tensor(
        labels, dtype=torch.int64, device=device)
    pred_instances.scores = torch.tensor(
        scores, dtype=torch.float32, device=device)
    return DetDataSample(pred_instances=pred_instances)
def main(args):
    """Compare local detector output with the inference server's output.

    Runs the detector built from ``args.config`` / ``args.checkpoint`` on
    ``args.img``, keeps the predictions whose score is at least
    ``args.score_thr``, posts the same image file to
    ``http://<inference_addr>/predictions/<model_name>``, draws both results
    with the model's visualizer, and finally checks that boxes, labels and
    scores of the two results agree.

    Args:
        args (argparse.Namespace): Options as produced by ``parse_args``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        AssertionError: If the two results differ in the number of
            predictions or are not element-wise close.
    """
    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    # test a single image
    pytorch_results = inference_detector(model, args.img)
    keep = pytorch_results.pred_instances.scores >= args.score_thr
    pytorch_results.pred_instances = pytorch_results.pred_instances[keep]

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    # the dataset_meta is loaded from the checkpoint and
    # then passed to the model in init_detector
    visualizer.dataset_meta = model.dataset_meta

    # show the results
    img = mmcv.imread(args.img)
    img = mmcv.imconvert(img, 'bgr', 'rgb')
    pt_out_file = None
    ts_out_file = None
    if args.work_dir is not None:
        os.makedirs(args.work_dir, exist_ok=True)
        pt_out_file = os.path.join(args.work_dir, 'pytorch_result.png')
        ts_out_file = os.path.join(args.work_dir, 'torchserve_result.png')

    visualizer.add_datasample(
        'pytorch_result',
        img.copy(),  # copy so the second drawing starts from a clean image
        data_sample=pytorch_results,
        draw_gt=False,
        out_file=pt_out_file,
        show=True,
        wait_time=0)

    url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
    with open(args.img, 'rb') as image:
        response = requests.post(url, data=image)
    # Fail with the HTTP error itself instead of a confusing failure while
    # interpreting an error payload as predictions.
    response.raise_for_status()

    metainfo = pytorch_results.pred_instances.metainfo
    ts_results = align_ts_output(response.json(), metainfo, args.device)

    visualizer.add_datasample(
        'torchserve_result',
        img,
        data_sample=ts_results,
        draw_gt=False,
        out_file=ts_out_file,
        show=True,
        wait_time=0)

    # Explicit raises rather than ``assert`` so the verification still runs
    # when Python is started with -O.
    for field in ('bboxes', 'labels', 'scores'):
        expected = getattr(pytorch_results.pred_instances, field)
        actual = getattr(ts_results.pred_instances, field)
        # Report a differing prediction count directly instead of leaving
        # it to torch.allclose.
        if len(expected) != len(actual):
            raise AssertionError(
                f'{field}: local model produced {len(expected)} entries, '
                f'server returned {len(actual)}')
        if not torch.allclose(expected, actual):
            raise AssertionError(
                f'{field} differ between the local model and the server')
if __name__ == '__main__':
    # Script entry point: parse the command line, then run the comparison.
    args = parse_args()
    main(args)