File size: 3,147 Bytes
45099b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ae5318
45099b6
 
 
 
 
d749e18
45099b6
 
 
 
 
d749e18
45099b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
from functools import partial
from typing import List, Union

import numpy as np
import numpy.typing as npt
import torch

import plasma.huggingface as hf

from .module.yolov5.yolo_utils.datasets import letterbox
from .module.yolov5.yolo_utils.general import non_max_suppression, scale_coords
from .preprocess import create_batch, process_image
from .utils import (
    DETECTOR_WEIGHT_ID,
    check_image_shape,
    load_yolo_model,
)


class StampDetector:
    """YOLOv5-based stamp detector.

    Call an instance on a list of images (or a numpy array whose first axis
    indexes images) to get, for each image, a list of integer bounding boxes
    ``[xmin, ymin, xmax, ymax]``.
    """

    def __init__(
        self, model_path: Union[str, None] = None, device: str = "cpu", conf_thres: float = 0.5, iou_thres: float = 0.3
    ) -> None:
        """Create an object for stamp detection.

        Args:
            model_path: identifier passed to ``hf.download_file`` to obtain the
                model checkpoint.
            device: torch device the model is loaded on; also forwarded to
                ``process_image`` for every input image.
            conf_thres: confidence threshold passed to non-max suppression.
            iou_thres: IoU threshold passed to non-max suppression.
        """
        # NOTE(review): DETECTOR_WEIGHT_ID is imported by this module but unused;
        # it may have been intended as the fallback when model_path is None.
        # Confirm how hf.download_file handles a None argument.
        # NOTE(review): __detect casts inputs to float16 (.half()); half-precision
        # inference may not be supported by every CPU backend. Confirm for device="cpu".
        checkpoint = hf.download_file(model_path)

        self.device = device
        self.model, self.stride = load_yolo_model(checkpoint, device=device)

        self.img_size = 640  # letterbox target size
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres

        self.process_func_ = partial(process_image, device=device)

    def __call__(self, image_list: Union[List[npt.NDArray], npt.NDArray]) -> List[npt.NDArray]:
        """Detect stamps in every image of ``image_list``.

        Args:
            image_list (Union[List[npt.NDArray], npt.NDArray]): a list of images,
                or a numpy array whose first axis indexes images.

        Returns:
            One entry per input image, in input order. Each entry is a Python
            list of ``[xmin, ymin, xmax, ymax]`` integer boxes; it is empty when
            nothing was detected. (Entries are nested lists, not numpy arrays,
            despite the return annotation.)

        Raises:
            TypeError: if ``image_list`` is neither a list nor a numpy array.
        """
        if not isinstance(image_list, (np.ndarray, list)):
            raise TypeError("Invalid Type: Input must be of type list or np.ndarray")

        if len(image_list) == 0:
            return []

        # Only the first image is validated.
        check_image_shape(image_list[0])
        return self.__detect(image_list)  # type: ignore

    def __detect(self, image_list):  # type: ignore
        """Run batched inference on a non-empty ``image_list``; use ``__call__`` instead."""
        # create_batch receives the set of distinct image shapes. Presumably it groups
        # same-shaped images so each batch can be stacked; `indices` maps batch order
        # back to input order. Confirm against preprocess.create_batch.
        batches, indices = create_batch(image_list, {x.shape for x in image_list})
        predictions = []

        for origin_images in batches:
            # Honour the configured input size and the stride reported by the model
            # instead of hard-coded 640 / 32.
            images = [
                letterbox(x, self.img_size, stride=int(self.stride))[0] for x in origin_images
            ]  # type: ignore
            images = [self.process_func_(x) for x in images]
            tensor = torch.stack(images).half()

            with torch.no_grad():
                pred = self.model(tensor)[0]

            # Use the thresholds given to the constructor (previously ignored).
            pred = non_max_suppression(
                pred, self.conf_thres, self.iou_thres, classes=0, agnostic=1
            )  # type: ignore

            for idx, det in enumerate(pred):
                if not len(det):
                    predictions.append([])
                    continue
                # Map boxes from the letterboxed input back to this image's original size.
                det[:, :4] = scale_coords(
                    images[idx].shape[1:], det[:, :4], origin_images[idx].shape
                )  # type: ignore
                predictions.append(det[:, :4].round().cpu().numpy().astype("int").tolist())

        return self._restore_order(predictions, indices)

    @staticmethod
    def _restore_order(predictions, indices):
        """Return ``predictions`` sorted by their paired entry in ``indices``.

        ``predictions[i]`` belongs to input position ``indices[i]``. The result
        lists predictions in ascending input position.
        """
        ordered = sorted(zip(predictions, indices), key=lambda pair: pair[1])
        return [prediction for prediction, _ in ordered]