import base64
import os

import mmcv
import torch
from ts.torch_handler.base_handler import BaseHandler

from mmdet.apis import inference_detector, init_detector


class MMdetHandler(BaseHandler):
    """TorchServe handler that serves an MMDetection detector.

    Each request row carries an encoded image, either as raw bytes or as a
    base64 string, under the ``data`` or ``body`` key. The response holds one
    list of detections per image.
    """

    # Detections with a confidence score below this value are dropped in
    # ``postprocess``.
    threshold = 0.5

    def initialize(self, context):
        """Load the detector from the model directory and select the device.

        Args:
            context: TorchServe context object. ``context.system_properties``
                supplies ``model_dir`` and ``gpu_id``. ``context.manifest``
                supplies ``['model']['serializedFile']``, the checkpoint name.
                The model config is read from ``<model_dir>/config.py``.
        """
        properties = context.system_properties
        self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        gpu_id = properties.get('gpu_id')
        if self.map_location == 'cuda' and gpu_id is not None:
            self.device = torch.device(f'{self.map_location}:{gpu_id}')
        else:
            # Without a gpu_id the old code produced 'cuda:None', which
            # torch.device rejects; fall back to the backend's default device.
            self.device = torch.device(self.map_location)
        self.manifest = context.manifest

        model_dir = properties.get('model_dir')
        serialized_file = self.manifest['model']['serializedFile']
        checkpoint = os.path.join(model_dir, serialized_file)
        self.config_file = os.path.join(model_dir, 'config.py')

        self.model = init_detector(self.config_file, checkpoint, self.device)
        self.initialized = True

    def preprocess(self, data):
        """Decode each request row into an image array.

        Args:
            data: Sequence of dict-like request rows. The payload is read from
                ``'data'``. It falls back to ``'body'`` when ``'data'`` is
                missing or falsy.

        Returns:
            list: The images decoded by ``mmcv.imfrombytes``, in request order.
        """
        images = []
        for row in data:
            image = row.get('data') or row.get('body')
            if isinstance(image, str):
                # Text payloads are treated as base64-encoded image bytes.
                image = base64.b64decode(image)
            images.append(mmcv.imfrombytes(image))
        return images

    def inference(self, data, *args, **kwargs):
        """Run the detector on the preprocessed images.

        Extra positional and keyword arguments are accepted for interface
        compatibility and are ignored.

        Returns:
            The raw result of ``inference_detector(self.model, data)``.
        """
        return inference_detector(self.model, data)

    def postprocess(self, data):
        """Convert raw detector output into JSON-serialisable detections.

        Each image maps to a list of ``{class_name: bbox_coords, 'score':
        score}`` dicts, following the example ObjectDetectionHandler format.
        Only detections with ``score >= self.threshold`` are kept.

        Args:
            data: Iterable of per-image results from ``inference``. A result
                is either a per-class sequence of bbox arrays, or a tuple
                whose first element is that sequence. The tuple form is
                presumably ``(bbox, segm)`` from mask models; verify against
                the deployed model.

        Returns:
            list[list[dict]]: One detection list per input image.
        """
        class_names = self.model.CLASSES
        output = []
        for image_result in data:
            # Only boxes are reported. Any extra elements of a tuple result,
            # such as segmentation masks, are ignored.
            if isinstance(image_result, tuple):
                bbox_result = image_result[0]
            else:
                bbox_result = image_result

            detections = []
            for class_index, class_result in enumerate(bbox_result):
                class_name = class_names[class_index]
                for bbox in class_result:
                    # The last column is the confidence score and the rest are
                    # box coordinates (presumably [x1, y1, x2, y2, score]).
                    score = float(bbox[-1])
                    if score >= self.threshold:
                        detections.append({
                            class_name: bbox[:-1].tolist(),
                            'score': score
                        })
            output.append(detections)
        return output