ai / yolo_predictor.py
neoguojing
init
68d34d0
raw
history blame contribute delete
No virus
2.57 kB
from pytorch_model_factory import TorchModelFactory
from PIL import Image
from typing import Any, Dict
from detectron2.structures import Instances
import torch
import threading
import gc
class YOLOPredictor:
    """Process-wide singleton wrapper around a YOLO model.

    The model is created via ``TorchModelFactory.create_yolo_detect_model()``
    the first time the class is instantiated; subsequent instantiations return
    the same object until :meth:`release` is called, after which the next
    instantiation builds a fresh instance (and reloads the model).
    """

    # The shared instance; None until first construction or after release().
    _instance = None
    # Guards creation and teardown of ``_instance`` across threads.
    _lock = threading.Lock()

    def __new__(cls, cfg=None):
        # Double-checked locking: cheap unlocked check, then re-check under the lock.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # Load the model BEFORE publishing the instance: threads that
                    # see a non-None _instance skip the lock entirely, so they must
                    # never observe a half-initialised object. This also means a
                    # failing model load leaves no broken singleton behind.
                    instance._initialize(cfg)
                    cls._instance = instance
        return cls._instance

    def _initialize(self, cfg=None):
        """Create the underlying model. ``cfg`` is accepted but currently unused."""
        self.model = TorchModelFactory.create_yolo_detect_model()

    def __call__(self, image):
        """Run the model on a single image.

        Args:
            image: the image to run inference on; it is passed to the model
                wrapped in a one-element list. NOTE(review): the previous
                docstring described a PIL image of shape (H, W, C) in BGR
                order; PIL images are normally RGB - confirm what callers pass.

        Returns:
            None if the model has been released; otherwise the
            ``(result, pil_images)`` tuple produced by ``_post_processor``.
        """
        if self.model is None:
            return None
        predictions = self.model([image])
        return self._post_processor(predictions)

    def _post_processor(self, output):
        """Convert raw model output into a detectron2-style result plus rendered images.

        Args:
            output: iterable of per-image result objects returned by the model.
                Each is read via ``plot()``, ``orig_shape``, ``boxes``,
                ``masks``, ``probs``, ``keypoints`` and ``obb``.

        Returns:
            A tuple ``(result, pil_images)``. ``result`` is a dict whose
            ``"instances"`` entry is a detectron2 ``Instances`` built from the
            LAST element of ``output`` (None if ``output`` is empty);
            ``pil_images`` holds one PIL image per element of ``output``, made
            from the array returned by that element's ``plot()``.
        """
        pil_images = []
        result: Dict[str, Any] = {"instances": None}
        # TODO: only a single image is supported - every iteration overwrites
        # result["instances"]. __call__ always passes exactly one image.
        for res in output:
            plotted_bgr = res.plot()
            # Reverse the channel axis (BGR -> RGB) before handing it to PIL.
            pil_images.append(Image.fromarray(plotted_bgr[..., ::-1]))

            instances = Instances(res.orig_shape)
            result["instances"] = instances
            if res.boxes is not None:
                # NOTE(review): detectron2 code conventionally expects pred_boxes
                # to be a Boxes object in XYXY format; this stores the raw xywh
                # tensor. Confirm downstream consumers expect that.
                instances.pred_boxes = res.boxes.xywh
            if res.masks is not None:
                instances.pred_masks = res.masks.xyn
            if res.probs is not None:
                # NOTE(review): probs.top1 looks like a class index rather than a
                # confidence, and Instances fields are normally length-checked
                # with len(); verify this path for classification models.
                instances.scores = res.probs.top1
            if res.keypoints is not None:
                instances.pred_keypoints = res.keypoints.xyn
            if res.obb is not None:
                instances.pred_obb = res.obb.xywhr
        return result, pil_images

    def release(self):
        """Free the model and its GPU memory.

        Sets ``self.model`` to None (so ``__call__`` returns None instead of
        raising) and, if this object is the current singleton, clears it so
        the next ``YOLOPredictor()`` builds a fresh instance. Safe to call
        more than once.
        """
        cls = type(self)
        with cls._lock:
            if cls._instance is self:
                cls._instance = None
        # Drop the reference rather than deleting the attribute, so the
        # `self.model is None` guard in __call__ keeps working.
        self.model = None
        # Collect garbage first so tensors kept alive only by reference cycles
        # are freed; empty_cache can only return blocks no tensor still holds.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# if __name__ == "__main__":
# f = YOLOPredictor()
# from PIL import Image
# img = Image.open("./test/test.png")
# f(img)