File size: 4,242 Bytes
a083fd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""ModelContainer class used for loading the model in the model wrapper."""

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Any, NamedTuple

from model_api.adapters import OpenvinoAdapter, create_core
from model_api.models import Model

from .utils import get_model_path, get_parameters

if TYPE_CHECKING:
    from pathlib import Path

    import numpy as np
    from model_api.tilers import DetectionTiler, InstanceSegmentationTiler


class TaskType(str, Enum):
    """OTX task type definition.

    Because the enum mixes in ``str``, each member compares equal to its string
    value (e.g. ``TaskType.DETECTION == "DETECTION"``). ``ModelWrapper`` resolves
    members by name with ``TaskType[...]`` from the upper-cased ``task_type``
    entry of the model's ``config.json``.
    """

    # Each value is identical to its member name, so name lookup and value lookup agree.
    CLASSIFICATION = "CLASSIFICATION"
    DETECTION = "DETECTION"
    INSTANCE_SEGMENTATION = "INSTANCE_SEGMENTATION"
    SEGMENTATION = "SEGMENTATION"


class ModelWrapper:
    """Class for storing the model wrapper based on Model API and needed parameters of model.

    Reads ``config.json`` from ``model_dir``, builds a Model API model with
    ``Model.create_model`` on top of an ``OpenvinoAdapter``, and sets up an
    optional tiler.

    Args:
        model_dir (Path): path to model directory. It must contain ``config.json``;
            the model file is resolved with ``get_model_path(model_dir / "model.xml")``.
        device (str): device passed to ``OpenvinoAdapter`` and ``setup_tiler``.
            Defaults to ``"CPU"``.

    Raises:
        RuntimeError: if ``config.json`` does not exist in ``model_dir``.
    """

    def __init__(self, model_dir: Path, device: str = "CPU") -> None:
        # Validate the directory before constructing the OpenVINO adapter, so an
        # invalid directory is rejected without doing the (potentially expensive)
        # adapter/core setup first.
        config_path = model_dir / "config.json"
        if not config_path.exists():
            msg = "config.json doesn't exist in the model directory."
            raise RuntimeError(msg)

        model_adapter = OpenvinoAdapter(create_core(), get_model_path(model_dir / "model.xml"), device=device)

        self.parameters = get_parameters(config_path)
        self._labels = self.parameters["model_parameters"]["labels"]
        self._task_type = TaskType[self.parameters["task_type"].upper()]

        # Labels for Model API wrappers can be omitted: they are unused in pre- and
        # postprocessing and the model already contains the correct labels.
        # Build a filtered copy instead of popping from an alias, so that
        # ``self.parameters["model_parameters"]`` is left intact.
        self.model_parameters = {
            key: value for key, value in self.parameters["model_parameters"].items() if key != "labels"
        }

        self.core_model = Model.create_model(
            model_adapter,
            self.parameters["model_type"],
            self.model_parameters,
            preload=True,
        )
        self.tiler = self.setup_tiler(model_dir, device)

    def setup_tiler(
        self,
        model_dir: Path,
        device: str,
    ) -> DetectionTiler | InstanceSegmentationTiler | None:
        """Set up tiler for model.

        Args:
            model_dir (Path): model directory (not used yet).
            device (str): device to run model on (not used yet).

        Returns:
            DetectionTiler | InstanceSegmentationTiler | None: ``None`` when tiling
            is absent or disabled in ``self.parameters["tiling_parameters"]``.

        Raises:
            NotImplementedError: if ``tiling_parameters["enable_tiling"]`` is truthy.
        """
        tiling_parameters = self.parameters.get("tiling_parameters") or {}
        # A tiling section without an "enable_tiling" flag is treated as disabled
        # rather than failing with a KeyError.
        if not tiling_parameters.get("enable_tiling"):
            return None

        msg = "Tiling has not been implemented yet"
        raise NotImplementedError(msg)

    @property
    def task_type(self) -> TaskType:
        """Task type parsed from ``config.json``."""
        return self._task_type

    @property
    def labels(self) -> dict:
        """Labels read from ``config.json`` ``model_parameters``."""
        return self._labels

    def infer(self, frame: np.ndarray) -> tuple[NamedTuple, dict]:
        """Infer with original image.

        Args:
            frame (np.ndarray): input image.

        Returns:
            tuple[NamedTuple, dict]: predictions of the core model and a meta dict
            holding the input's ``original_shape``.
        """
        # The Model API call performs preprocessing, inference and postprocessing (sync).
        predictions = self.core_model(frame)
        frame_meta = {"original_shape": frame.shape}

        return predictions, frame_meta

    def infer_tile(self, frame: np.ndarray) -> tuple[NamedTuple, dict]:
        """Infer by patching full image to tiles.

        Args:
            frame (np.ndarray): input image.

        Returns:
            tuple[NamedTuple, dict]: tiler predictions and a meta dict holding the
            input's ``original_shape``.

        Raises:
            RuntimeError: if no tiler has been set up.
        """
        if self.tiler is None:
            msg = "Tiler is not set"
            raise RuntimeError(msg)
        detections = self.tiler(frame)
        return detections, {"original_shape": frame.shape}

    def __call__(self, input_data: np.ndarray) -> tuple[Any, dict]:
        """Run inference, using the tiler when one is configured.

        Args:
            input_data (np.ndarray): The input image.

        Returns:
            tuple[Any, dict]: A tuple containing predictions and the meta information.
        """
        if self.tiler is not None:
            return self.infer_tile(input_data)
        return self.infer(input_data)