File size: 2,574 Bytes
68d34d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

import gc
import logging
import threading
from typing import Any, Dict

import torch
from detectron2.structures import Instances
from PIL import Image

from pytorch_model_factory import TorchModelFactory

class YOLOPredictor:
    """Process-wide singleton wrapping a YOLO detection model.

    The model is built once, on the first construction, via
    ``TorchModelFactory.create_yolo_detect_model()``. Every later
    ``YOLOPredictor(...)`` call returns the same instance without rebuilding
    the model.
    """

    _instance = None  # the shared instance; published only after full initialization
    _lock = threading.Lock()  # serializes first-time creation of _instance
    _logger = logging.getLogger(__name__)

    def __new__(cls, cfg=None):
        """Return the shared instance, creating and initializing it on first use.

        Args:
            cfg: forwarded to ``_initialize`` on the first call only; ignored
                afterwards (``_initialize`` currently does not use it either).
        """
        # Unlocked check for the common "already created" path, then a
        # re-check under the lock so racing threads build only one model.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # Initialize before publishing: otherwise another thread on
                    # the unlocked path could see an instance with no `model`,
                    # and a failing _initialize would leave a broken instance
                    # cached forever.
                    instance._initialize(cfg)
                    cls._instance = instance
        return cls._instance

    def _initialize(self, cfg=None):
        """Build the underlying model. ``cfg`` is accepted but currently unused."""
        self.model = TorchModelFactory.create_yolo_detect_model()

    def __call__(self, image):
        """Run the model on a single image and post-process the result.

        Args:
            image: one input image; it is passed to the model as a one-element
                list. NOTE(review): an earlier docstring described this as a
                "PIL image ... in BGR order"; PIL images are normally RGB, so
                confirm what callers actually pass.

        Returns:
            ``None`` if no model is loaded (e.g. after :meth:`release`);
            otherwise the ``(result, pil_images)`` tuple from
            ``_post_processor``.
        """
        if getattr(self, "model", None) is None:
            return None

        predictions = self.model([image])
        return self._post_processor(predictions)

    def _post_processor(self, output):
        """Convert raw model output into a detectron2-style dict plus rendered images.

        Args:
            output: iterable of per-image result objects returned by the model.
                Each is expected to provide ``plot()``, ``orig_shape`` and the
                optional ``boxes``/``masks``/``probs``/``keypoints``/``obb``
                attributes read below.

        Returns:
            ``(result, pil_images)``: ``result`` is ``{"instances": ...}``
            holding an ``Instances`` built from the last element of ``output``
            (``None`` if ``output`` is empty); ``pil_images`` holds one
            rendered ``PIL.Image`` per element of ``output``.
        """
        self._logger.debug("raw model output:\n%s", output)
        pil_images = []
        result: Dict[str, Any] = {"instances": None}

        # TODO: only one image is supported; each iteration overwrites
        # result["instances"], so only the last result is kept.
        for o in output:
            # Reverse the channel axis before handing the array to PIL
            # (assumes plot() returns BGR, as the original code did).
            im_bgr = o.plot()
            pil_images.append(Image.fromarray(im_bgr[..., ::-1]))

            instances = Instances(o.orig_shape)
            result["instances"] = instances

            if o.boxes is not None:
                self._logger.debug(
                    "boxes xywh %s, shape %s", o.boxes.xywh, o.boxes.xywh.shape
                )
                # NOTE(review): detectron2 consumers usually expect pred_boxes
                # as a Boxes object in (x1, y1, x2, y2); xywh may instead be
                # (center_x, center_y, w, h). Confirm against consumers.
                instances.pred_boxes = o.boxes.xywh

            if o.masks is not None:
                instances.pred_masks = o.masks.xyn

            if o.probs is not None:
                # NOTE(review): probs.top1 looks like a single class index,
                # not per-instance scores; detectron2's Instances appears to
                # call len() on every field value, so this may raise. Confirm.
                instances.scores = o.probs.top1

            if o.keypoints is not None:
                instances.pred_keypoints = o.keypoints.xyn

            if o.obb is not None:
                instances.pred_obb = o.obb.xywhr

        return result, pil_images

    def release(self):
        """Drop the model reference and try to return its memory.

        Safe to call more than once. Afterwards ``self.model`` is ``None`` and
        :meth:`__call__` returns ``None``. The singleton itself is kept, so a
        later ``YOLOPredictor()`` returns this same instance without
        reloading the model.
        """
        # Rebind instead of `del` so the None-guard in __call__ works and a
        # second release() does not raise AttributeError.
        self.model = None
        # Collect first so tensors kept alive only by reference cycles are
        # actually freed, then release unoccupied cached CUDA memory.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()



# if __name__ == "__main__":
#     f = YOLOPredictor()
#     from PIL import Image
#     img = Image.open("./test/test.png")
#     f(img)