File size: 3,750 Bytes
68d34d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101

from pytorch_model_factory import TorchModelFactory
from detectron2.data import MetadataCatalog
import torch
import torchvision.transforms as transforms

class PytorchPredictor:
    """Single-image inference wrapper around a torch/torchvision model.

    The model is selected by ``cfg.TASK_TYPE``:

    * ``"classfication"`` (sic — spelling preserved for config compatibility)
      and ``"feature"`` both load a ResNet feature extractor.
    * ``"semantic"`` loads a DeepLabV3 segmentation model and disables
      resize/crop so inference runs on the full-resolution image.

    Any other task type leaves ``self.model`` as ``None`` and ``__call__``
    returns ``None``.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: detectron2-style config node. Must provide ``TASK_TYPE`` and,
                for the "semantic" task, ``DATASETS.TEST``.
        """
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.task_type = cfg.TASK_TYPE
        self.resize = 256
        self.crop = 224
        # Initialize explicitly so __call__/release stay safe even when
        # TASK_TYPE matches no branch below (the original left it unset,
        # which made the `self.model is None` guards raise AttributeError).
        self.model = None
        if self.task_type == "classfication":  # spelling kept: external configs use it
            self.model = TorchModelFactory.create_feature_extract_model("resnet")
        elif self.task_type == "feature":
            self.model = TorchModelFactory.create_feature_extract_model("resnet")
        elif self.task_type == "semantic":
            self.model = TorchModelFactory.create_semantic_model("deeplabv3")
            self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
            if not hasattr(self.metadata, "stuff_classes"):
                self.metadata.stuff_classes = self.metadata.thing_classes
            if len(self.metadata.stuff_classes) == 20:
                # A 20-entry class list (presumably Pascal-VOC-style — TODO
                # confirm dataset) lacks background; prepend it so class ids
                # align with the model's output channels.
                self.metadata.stuff_classes.insert(0, "background")

            print(self.metadata)
            # Semantic segmentation keeps the original image resolution.
            self.resize = None
            self.crop = None

    def __call__(self, image):
        """Run inference on one image.

        Args:
            image (PIL image): the input image (RGB expected by the
                normalization constants — TODO confirm; original docstring
                said BGR).

        Returns:
            dict or None: task-specific predictions from ``_post_processor``,
            or ``None`` when no model is loaded.
        """
        if self.model is None:
            return None

        input_batch = self.image_processor(image).unsqueeze(0)  # add batch dim

        if torch.cuda.is_available():
            input_batch = input_batch.cuda()

        # Inference only — no autograd bookkeeping needed.
        with torch.no_grad():
            predictions = self.model(input_batch)
            return self._post_processor(predictions)

    def image_processor(self, input_image):
        """Convert a PIL image to a normalized CHW float tensor.

        Resize/center-crop steps are applied only when ``self.resize`` /
        ``self.crop`` are set (they are disabled for the semantic task).
        """
        steps = []
        if self.resize is not None:
            steps.append(transforms.Resize(self.resize))
        if self.crop is not None:
            steps.append(transforms.CenterCrop(self.crop))
        steps.append(transforms.ToTensor())
        # Standard ImageNet normalization statistics.
        steps.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        )
        return transforms.Compose(steps)(input_image)

    def _post_processor(self, output):
        """Convert raw model output into a task-specific result dict.

        Returns:
            dict or None:
                * "classfication": list of {"feature", "score", "pred_class"}
                  per batch item (top-1 prediction).
                * "feature": {"features": <raw output on CPU>}.
                * "semantic": {"sem_segs": <argmax over class channel, on CPU>}.
                * None for unrecognized task types.
        """
        result = None
        if self.task_type == "classfication":
            output = output.cpu()
            result = {"classfication": []}
            probabilities = torch.nn.functional.softmax(output, dim=1)
            for i, probability in enumerate(probabilities):
                top1_prob, top1_catid = torch.topk(probability, 1)
                result["classfication"].append(
                    {"feature": output[i], "score": top1_prob, "pred_class": top1_catid}
                )
        elif self.task_type == "feature":
            result = {"features": output.cpu()}
        elif self.task_type == "semantic":
            # output["out"] is (N, C, H, W); argmax over C gives per-pixel ids.
            result = {"sem_segs": output["out"].argmax(1).cpu()}

        return result

    def release(self):
        """Drop the model reference and reclaim memory.

        Safe to call more than once; the predictor stays usable as a
        "no model" instance afterwards (``__call__`` returns ``None``).
        """
        import gc
        # Rebind instead of `del` so later __call__/release invocations see
        # self.model = None rather than raising AttributeError.
        self.model = None
        # Original checked cfg.MODEL.DEVICE == "gpu", but torch device strings
        # are "cuda", so the cache was likely never cleared; guard on actual
        # CUDA availability instead.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Force a collection pass so the freed model is reclaimed promptly.
        gc.collect()