import gc
import threading
from typing import Dict

import torch
from PIL import Image
from detectron2.structures import Instances

from pytorch_model_factory import TorchModelFactory

class YOLOPredictor:
    """Thread-safe singleton wrapper around a YOLO detection model."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, cfg=None):
        # Double-checked locking so the model is only loaded once.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super(YOLOPredictor, cls).__new__(cls)
                    cls._instance._initialize(cfg)
        return cls._instance

    def _initialize(self, cfg=None):
        # `cfg` is currently unused; the factory builds its default detection model.
        self.model = TorchModelFactory.create_yolo_detect_model()

    def __call__(self, image):
        """
        Args:
            image (PIL.Image.Image): the input image, shape (H, W, C) in RGB order.

        Returns:
            tuple:
                - predictions (dict): the output of the model for this one image,
                  stored under the "instances" key as a detectron2 ``Instances``.
                - images (list[PIL.Image.Image]): rendered prediction images.
            See :doc:`/tutorials/models` for details about the ``Instances`` format.
        """
        if self.model is None:
            return None
        predictions = self.model([image])
        return self._post_processor(predictions)

    def _post_processor(self, output):
        print("-------------------\n", output)
        pil_images = []
        result: Dict[str, Instances] = {
            "instances": None
        }
        # TODO: only a single image is supported for now
        for i, o in enumerate(output):
            # Render the prediction and convert BGR -> RGB for PIL.
            im_bgr = o.plot()
            im_rgb = Image.fromarray(im_bgr[..., ::-1])
            pil_images.append(im_rgb)

            result["instances"] = Instances(o.orig_shape)
            if o.boxes is not None:
                print(o.boxes.xywh, o.boxes.xywh.shape)
                # Boxes in (x_center, y_center, w, h) format.
                result["instances"].pred_boxes = o.boxes.xywh
            if o.masks is not None:
                # Normalized segment polygons.
                result["instances"].pred_masks = o.masks.xyn
            if o.probs is not None:
                # Top-1 classification confidence (o.probs.top1 is the class index).
                result["instances"].scores = o.probs.top1conf
            if o.keypoints is not None:
                # Normalized keypoint coordinates.
                result["instances"].pred_keypoints = o.keypoints.xyn
            if o.obb is not None:
                # Oriented boxes in (x_center, y_center, w, h, rotation) format.
                result["instances"].pred_obb = o.obb.xywhr
        return result, pil_images

    def release(self):
        # Drop the model reference so __call__ falls back to the `is None` guard.
        self.model = None
        # Clear the GPU cache.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Manually trigger garbage collection.
        gc.collect()

# if __name__ == "__main__":
#     f = YOLOPredictor()
#     from PIL import Image
#     img = Image.open("./test/test.png")
#     f(img)
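
# A slightly fuller usage sketch than the example above: it unpacks the
# (result, images) tuple returned by __call__ and reads fields off the
# detectron2 Instances. The output path and the availability of a model from
# TorchModelFactory are assumptions, not part of this module.
#
# if __name__ == "__main__":
#     predictor = YOLOPredictor()          # repeated constructions return the same instance
#     img = Image.open("./test/test.png")
#     result, rendered = predictor(img)
#     instances = result["instances"]
#     if instances is not None and instances.has("pred_boxes"):
#         print(instances.pred_boxes)      # xywh tensor stored by _post_processor
#     rendered[0].save("./test/test_pred.png")   # hypothetical output path
#     predictor.release()                  # free the model and clear the GPU cache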