"""Thin factory around detectron2-style predictors for several vision tasks.

``ModelFactory`` builds a config per task, lazily creates one cached
``InferenceBase`` predictor per model category, runs it on an image and turns
the raw output into plain Python containers.
"""

# import some common libraries
import numpy as np

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.utils.visualizer import ColorMode
import detectron2.data.transforms as T
from predictor import InferenceBase
import torch
import torchvision.transforms as transforms
from PIL import Image
from detectron2.data.detection_utils import pil_image_handler


# Constants naming the supported model categories.
class ModelCategory:
    """String identifiers used to select a config and to key cached predictors."""

    IMAGE_FEATURE_EXTRACT = "image_feature_extract"
    IMAGE_CLASSIFICATION = "image_classification"
    OBJECT_DETECTION = "object_detection"
    ONE_STEP_OBJECT_DETECTION = "onestep_object_detection"
    SEMANTIC_SEGMENTATION = "semantic_segmentation"
    INSTANCE_SEGMENTATION = "instance_segmentation"
    PANOPTIC_SEGMENTATION = "panoptic_segmentation"
    KEYPOINTS = "keypoints"
    REGRESSION = "regression"
    TEXT_CLASSIFICATION = "text_classification"
    LANGUAGE_MODELLING = "language_modelling"
    TRANSLATION = "translation"
    QA_SYSTEM = "qa_system"
    RECOMMENDATION_SYSTEM = "recommendation_system"
    GENERATIVE_MODELLING = "generative_modelling"
    CONTROL = "control"
    ROBOTICS = "robotics"
    YOLO = "yolo"
    OTHERS = "others"


class ModelConfig:
    """Build the config object for one model category.

    Args:
        model_type: One of the ``ModelCategory`` constants.
        model_path: Model-zoo config name passed to
            ``model_zoo.get_checkpoint_url`` to fill ``MODEL.WEIGHTS``;
            skipped when ``None``.
        cfg_path: Config file merged on top of the defaults returned by
            ``get_cfg()``; skipped when ``None``.
        thresh_hold: Stored on the instance as ``self.thresh_hold``. It is
            not written into the config by this class.
    """

    # Config node created in __init__ via detectron2's get_cfg().
    cfg: object

    # Categories that get a custom TASK_TYPE and have their weights cleared.
    # NOTE: "classfication" (sic) is a runtime value and is kept unchanged;
    # presumably the predictor matches on this exact spelling.
    _TASK_TYPES = {
        ModelCategory.IMAGE_FEATURE_EXTRACT: "feature",
        ModelCategory.IMAGE_CLASSIFICATION: "classfication",
        ModelCategory.SEMANTIC_SEGMENTATION: "semantic",
        ModelCategory.YOLO: "yolo",
    }

    def __init__(self, model_type, model_path: str = None,
                 cfg_path: str = None, thresh_hold: float = 0.5):
        # Inside this method ``get_cfg`` resolves to the module-level
        # detectron2 import, not to the method of the same name below.
        self.cfg = get_cfg()
        if cfg_path is not None:
            self.cfg.merge_from_file(cfg_path)
        if model_path is not None:
            self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model_path)
        self.thresh_hold = thresh_hold

        task_type = self._TASK_TYPES.get(model_type)
        if task_type is not None:
            self.cfg.TASK_TYPE = task_type
            self.cfg.MODEL.WEIGHTS = None

    def get_cfg(self):
        """Return the config object built in ``__init__``."""
        return self.cfg


class ModelFactory:
    """Create, cache and run one ``InferenceBase`` predictor per category."""

    # Shared across all ModelFactory objects: category -> InferenceBase.
    _instances = {}

    def __init__(self):
        self.need_save_images = False

    @classmethod
    def get_instance(cls, category, cfg):
        """Return the cached predictor for ``category``, creating it on first use.

        ``cfg`` is only used when the predictor does not exist yet; later
        calls with the same ``category`` return the cached object regardless
        of the ``cfg`` passed.
        """
        if category not in cls._instances:
            cls._instances[category] = InferenceBase(cfg)
        return cls._instances[category]

    def serialize(self, output):
        """Convert a raw predictor output into plain Python containers.

        Args:
            output: Mapping returned by the predictor. The keys inspected are
                ``"instances"``, ``"sem_seg"``, ``"panoptic_seg"``,
                ``"sem_segs"``, ``"classfication"`` and ``"features"``.

        Returns:
            A dict (or, for ``"classfication"``, a list of dicts) holding
            ``tolist()``-converted values, or ``output`` itself when none of
            the keys produced a serialized value.
        """
        serialized = None

        if "instances" in output:
            instances = output["instances"]
            boxes = instances.pred_boxes
            # pred_boxes is either a plain tensor or an object exposing the
            # underlying tensor as ``.tensor``.
            if isinstance(boxes, torch.Tensor):
                box_list = boxes.tolist()
            else:
                box_list = boxes.tensor.tolist()
            serialized = {
                'image_height': instances.image_size[0],
                'image_width': instances.image_size[1],
                'pred_boxes': box_list,
                'scores': instances.scores.tolist() if instances.has("scores") else None,
                'pred_classes': instances.pred_classes.tolist() if instances.has("pred_classes") else None,
            }
            if hasattr(instances, "pred_masks"):
                # Masks are only reported, not serialized.
                print("instances.pred_masks", instances.pred_masks.shape)
            if hasattr(instances, "pred_keypoints"):
                serialized["pred_keypoints"] = instances.pred_keypoints.tolist()

        if "sem_seg" in output:
            # Only the shape is reported; the tensor itself is not serialized.
            print("sem_seg:", output["sem_seg"].shape)

        if "panoptic_seg" in output:
            print("panoptic_seg:", output["panoptic_seg"][0].shape)
            # Guard against outputs that carry "panoptic_seg" but no
            # "instances": without this the item assignment below would be
            # attempted on None.
            if serialized is None:
                serialized = {}
            serialized["panoptic_seg"] = output["panoptic_seg"][1]

        if "sem_segs" in output:
            print("sem_segs:", output["sem_segs"].shape)

        if "classfication" in output:
            serialized = []
            for item in output["classfication"]:
                print("classfication:", item["feature"].shape)
                serialized.append({
                    "score": item["score"].tolist(),
                    "pred_class": item["pred_class"].tolist(),
                })

        if "features" in output:
            print("features:", output["features"].shape)
            serialized = {
                "features": output["features"].tolist(),
            }

        if serialized is None:
            return output
        return serialized

    def predict(self, pil_image, task_type="panoptic"):
        """Run the task named by ``task_type`` on ``pil_image``.

        Args:
            pil_image: Image forwarded as ``input_image`` to the task method.
            task_type: One of ``"panoptic"``, ``"detect"``,
                ``"classification"``, ``"instance"``, ``"semantic"``,
                ``"feature"``, ``"keypoint"``, ``"onestep_detect"``, ``"yolo"``.

        Returns:
            ``(serialized_result, vis_output)``; ``vis_output`` is ``None``
            for ``"classification"`` and ``"feature"``.

        Raises:
            ValueError: If ``task_type`` is not one of the names above.
        """
        # Tasks whose method returns (result, vis_output).
        with_visual = {
            "panoptic": self.panoptic_segment,
            "detect": self.detect,
            "instance": self.instance_segment,
            "semantic": self.semantic_segment,
            "keypoint": self.keypoint,
            "onestep_detect": self.onstep_detect,
            "yolo": self.yolo,
        }
        # Tasks whose method returns only the result.
        result_only = {
            "classification": self.classify,
            "feature": self.extract,
        }

        vis_output = None
        if task_type in with_visual:
            result, vis_output = with_visual[task_type](input_image=pil_image)
        elif task_type in result_only:
            result = result_only[task_type](input_image=pil_image)
        else:
            supported = sorted(list(with_visual) + list(result_only))
            raise ValueError(
                f"Unsupported task_type {task_type!r}; expected one of {supported}"
            )
        return self.serialize(result), vis_output

    @staticmethod
    def _open_rgb(image_path):
        """Load ``image_path`` as an RGB PIL image and close the file handle."""
        with Image.open(image_path) as img:
            return img.convert('RGB')

    def _run_pil(self, category, cfg, input_image, image_path):
        """Run ``category``'s predictor; a path fallback is opened with PIL.

        The image (given or loaded) always goes through ``pil_image_handler``.
        """
        p = self.get_instance(category, cfg)
        if input_image is None and image_path is not None:
            input_image = self._open_rgb(image_path)
        input_image = pil_image_handler(input_image)
        outputs, vis_output = p.run_on_image(input_image)
        return outputs, vis_output

    def _run_with_reader(self, category, cfg, input_image, image_path):
        """Run ``category``'s predictor; a path fallback uses ``p.read_image``.

        Only an explicitly supplied image goes through ``pil_image_handler``;
        the result of ``p.read_image`` is passed to the predictor as is.
        """
        p = self.get_instance(category, cfg)
        if input_image is None and image_path is not None:
            input_image = p.read_image(image_path)
        else:
            input_image = pil_image_handler(input_image)
        outputs, vis_output = p.run_on_image(input_image)
        return outputs, vis_output

    def extract(self, input_image=None, image_path: str = "./test.png"):
        """Run the feature-extraction predictor on an image.

        Args:
            input_image: Image to process; when ``None``, ``image_path`` is
                loaded instead.
            image_path: Fallback image file.

        Returns:
            The first element of the predictor's ``run_on_image`` result.
        """
        cfg = ModelConfig(ModelCategory.IMAGE_FEATURE_EXTRACT,
                          model_path=None, cfg_path=None).get_cfg()
        outputs, _ = self._run_pil(ModelCategory.IMAGE_FEATURE_EXTRACT, cfg,
                                   input_image, image_path)
        return outputs

    def classify(self, input_image=None, image_path: str = "./cat.jpg"):
        """Run the image-classification predictor on an image.

        Args:
            input_image: Image to process; when ``None``, ``image_path`` is
                loaded instead.
            image_path: Fallback image file.

        Returns:
            The first element of the predictor's ``run_on_image`` result.
        """
        cfg = ModelConfig(ModelCategory.IMAGE_CLASSIFICATION,
                          model_path=None, cfg_path=None).get_cfg()
        outputs, _ = self._run_pil(ModelCategory.IMAGE_CLASSIFICATION, cfg,
                                   input_image, image_path)
        return outputs

    def onstep_detect(self, input_image=None, image_path: str = "./test.png",
                      confidence_threshold: float = 0.5):
        """Run one-step (RetinaNet config) object detection on an image.

        Args:
            input_image: Image to process; when ``None``, ``image_path`` is
                read with the predictor's ``read_image``.
            image_path: Fallback image file.
            confidence_threshold: Accepted for interface compatibility;
                currently not applied anywhere in this method.

        Returns:
            ``(outputs, vis_output)`` from the predictor's ``run_on_image``.
        """
        cfg = ModelConfig(
            ModelCategory.ONE_STEP_OBJECT_DETECTION,
            model_path="COCO-Detection/retinanet_R_101_FPN_3x.yaml",
            cfg_path="./configs/COCO-Detection/retinanet_R_101_FPN_3x.yaml",
        ).get_cfg()
        return self._run_with_reader(ModelCategory.ONE_STEP_OBJECT_DETECTION,
                                     cfg, input_image, image_path)

    def detect(self, input_image=None, image_path: str = "./test.png",
               confidence_threshold: float = 0.5):
        """Run object detection (Faster R-CNN config) on an image.

        Args:
            input_image: Image to process; when ``None``, ``image_path`` is
                read with the predictor's ``read_image``.
            image_path: Fallback image file.
            confidence_threshold: Accepted for interface compatibility;
                currently not applied anywhere in this method.

        Returns:
            ``(outputs, vis_output)`` from the predictor's ``run_on_image``.
        """
        cfg = ModelConfig(
            ModelCategory.OBJECT_DETECTION,
            model_path="COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml",
            cfg_path="./configs/COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml",
        ).get_cfg()
        # The predictor must be cached under its own category; sharing the
        # feature-extraction key would make detect() and extract() return
        # whichever predictor happened to be created first.
        return self._run_with_reader(ModelCategory.OBJECT_DETECTION,
                                     cfg, input_image, image_path)

    def instance_segment(self, input_image=None, image_path: str = "./test.png"):
        """Run instance segmentation (Mask R-CNN config) on an image.

        Args:
            input_image: Image to process; when ``None``, ``image_path`` is
                read with the predictor's ``read_image``.
            image_path: Fallback image file.

        Returns:
            ``(outputs, vis_output)`` from the predictor's ``run_on_image``.
        """
        cfg = ModelConfig(
            ModelCategory.INSTANCE_SEGMENTATION,
            model_path="COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml",
            cfg_path="./configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml",
        ).get_cfg()
        return self._run_with_reader(ModelCategory.INSTANCE_SEGMENTATION,
                                     cfg, input_image, image_path)

    def semantic_segment(self, input_image=None, image_path: str = "./test.png"):
        """Run the semantic-segmentation predictor on an image.

        Args:
            input_image: Image to process; when ``None``, ``image_path`` is
                loaded instead.
            image_path: Fallback image file.

        Returns:
            ``(outputs, vis_output)`` from the predictor's ``run_on_image``.
        """
        cfg = ModelConfig(
            ModelCategory.SEMANTIC_SEGMENTATION,
            model_path=None,
            cfg_path="./configs/PascalVOC-Detection/faster_rcnn_R_50_FPN.yaml",
        ).get_cfg()
        return self._run_pil(ModelCategory.SEMANTIC_SEGMENTATION, cfg,
                             input_image, image_path)

    def panoptic_segment(self, input_image=None, image_path: str = "./test.png"):
        """Run panoptic segmentation (Panoptic FPN config) on an image.

        Args:
            input_image: Image to process; when ``None``, ``image_path`` is
                read with the predictor's ``read_image``.
            image_path: Fallback image file.

        Returns:
            ``(outputs, vis_output)`` from the predictor's ``run_on_image``.
        """
        cfg = ModelConfig(
            ModelCategory.PANOPTIC_SEGMENTATION,
            model_path="COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml",
            cfg_path="./configs/COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml",
        ).get_cfg()
        # Cached under its own key so it cannot be confused with the
        # instance-segmentation predictor.
        return self._run_with_reader(ModelCategory.PANOPTIC_SEGMENTATION,
                                     cfg, input_image, image_path)

    def keypoint(self, input_image=None, image_path: str = "./test.png"):
        """Run keypoint detection (Keypoint R-CNN config) on an image.

        Args:
            input_image: Image to process; when ``None``, ``image_path`` is
                read with the predictor's ``read_image``.
            image_path: Fallback image file.

        Returns:
            ``(outputs, vis_output)`` from the predictor's ``run_on_image``.
        """
        cfg = ModelConfig(
            ModelCategory.KEYPOINTS,
            model_path="COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x.yaml",
            cfg_path="./configs/COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x.yaml",
        ).get_cfg()
        return self._run_with_reader(ModelCategory.KEYPOINTS,
                                     cfg, input_image, image_path)

    def yolo(self, input_image=None, image_path: str = "./test/test.png"):
        """Run the YOLO predictor on an image.

        Args:
            input_image: Image to process; when ``None``, ``image_path`` is
                loaded instead.
            image_path: Fallback image file.

        Returns:
            ``(outputs, vis_output)`` from the predictor's ``run_on_image``.
        """
        cfg = ModelConfig(ModelCategory.YOLO,
                          model_path=None, cfg_path=None).get_cfg()
        return self._run_pil(ModelCategory.YOLO, cfg, input_image, image_path)


# Manual smoke test (kept disabled):
# if __name__ == "__main__":
#     f = ModelFactory()
#     # f.prepare_meta()
#     out = f.yolo()
#     print(out)