# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""ModelContainer class used for loading the model in the model wrapper."""

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Any, NamedTuple

from model_api.adapters import OpenvinoAdapter, create_core
from model_api.models import Model

from .utils import get_model_path, get_parameters

if TYPE_CHECKING:
    from pathlib import Path

    import numpy as np
    from model_api.tilers import DetectionTiler, InstanceSegmentationTiler


class TaskType(str, Enum):
"""OTX task type definition."""
CLASSIFICATION = "CLASSIFICATION"
DETECTION = "DETECTION"
INSTANCE_SEGMENTATION = "INSTANCE_SEGMENTATION"
SEGMENTATION = "SEGMENTATION"
class ModelWrapper:
"""Class for storing the model wrapper based on Model API and needed parameters of model.
Args:
model_dir (Path): path to model directory
"""
    def __init__(self, model_dir: Path, device: str = "CPU") -> None:
        model_adapter = OpenvinoAdapter(create_core(), get_model_path(model_dir / "model.xml"), device=device)

        if not (model_dir / "config.json").exists():
            msg = "config.json doesn't exist in the model directory."
            raise RuntimeError(msg)

        self.parameters = get_parameters(model_dir / "config.json")
        self._labels = self.parameters["model_parameters"]["labels"]
        self._task_type = TaskType[self.parameters["task_type"].upper()]

        # Labels can be dropped for Model API wrappers: they are not used in pre- or post-processing,
        self.model_parameters = self.parameters["model_parameters"]
        # and the model itself already contains the correct labels.
        self.model_parameters.pop("labels")

        self.core_model = Model.create_model(
            model_adapter,
            self.parameters["model_type"],
            self.model_parameters,
            preload=True,
        )

        self.tiler = self.setup_tiler(model_dir, device)
    def setup_tiler(
        self,
        model_dir: Path,
        device: str,
    ) -> DetectionTiler | InstanceSegmentationTiler | None:
        """Set up the tiler for the model.

        Args:
            model_dir (Path): model directory
            device (str): device to run the model on

        Returns:
            DetectionTiler | InstanceSegmentationTiler | None: a tiler instance if tiling is enabled, otherwise None
        """
        if not self.parameters.get("tiling_parameters") or not self.parameters["tiling_parameters"]["enable_tiling"]:
            return None

        msg = "Tiling has not been implemented yet"
        raise NotImplementedError(msg)
    @property
    def task_type(self) -> TaskType:
        """Task type property."""
        return self._task_type

    @property
    def labels(self) -> dict:
        """Labels property."""
        return self._labels
    def infer(self, frame: np.ndarray) -> tuple[NamedTuple, dict]:
        """Run inference on the original (untiled) image.

        Args:
            frame (np.ndarray): input image

        Returns:
            predictions (NamedTuple): model predictions
            frame_meta (dict): dict with the original image shape
        """
        # A synchronous call to the model runs preprocessing, inference, and postprocessing.
        predictions = self.core_model(frame)
        frame_meta = {"original_shape": frame.shape}

        return predictions, frame_meta
    def infer_tile(self, frame: np.ndarray) -> tuple[NamedTuple, dict]:
        """Run inference by splitting the full image into tiles.

        Args:
            frame (np.ndarray): input image

        Returns:
            tuple[NamedTuple, dict]: predictions and a dict with the original image shape
        """
        if self.tiler is None:
            msg = "Tiler is not set"
            raise RuntimeError(msg)

        detections = self.tiler(frame)
        return detections, {"original_shape": frame.shape}
    def __call__(self, input_data: np.ndarray) -> tuple[Any, dict]:
        """Run inference on the input image, using the tiler when it is configured.

        Args:
            input_data (np.ndarray): the input image

        Returns:
            tuple[Any, dict]: predictions and the meta information
        """
        if self.tiler is not None:
            return self.infer_tile(input_data)
        return self.infer(input_data)
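

# Illustrative usage sketch, not part of the original module: the "exported_model"
# directory (containing model.xml, model.bin and config.json) and the sample image
# path are assumptions for demonstration only. Because this module uses a relative
# import (.utils), it is normally imported from its package rather than executed
# directly; the guard below only keeps the sketch self-contained.
if __name__ == "__main__":
    from pathlib import Path

    import cv2  # assumption: OpenCV is available for reading the demo image

    # Hypothetical paths; point them at a real exported model directory and image.
    model_dir = Path("exported_model")
    image = cv2.imread("sample.jpg")

    wrapper = ModelWrapper(model_dir, device="CPU")
    print(f"Task type: {wrapper.task_type}, labels: {wrapper.labels}")

    # __call__ dispatches to infer_tile() when tiling is enabled in config.json,
    # otherwise to infer() on the full frame.
    predictions, frame_meta = wrapper(image)
    print(predictions, frame_meta["original_shape"])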