File size: 3,788 Bytes
45099b6
 
 
 
 
 
 
 
 
adc3ff1
45099b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adc3ff1
45099b6
adc3ff1
45099b6
 
 
 
adc3ff1
 
 
45099b6
 
 
 
 
 
1c6f7e9
45099b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
from typing import List, Union

import numpy as np
import numpy.typing as npt

from .detector import StampDetector
from .module.unet import *
from .preprocess import create_batch
from .utils import REMOVER_WEIGHT_ID, check_image_shape, download_weight


class StampRemover:
    """Detect stamps on document images and erase them.

    Detection is delegated to ``StampDetector``. Each detected box is grown by
    ``self.padding`` pixels on every side and clipped to the page. The box is
    then cropped, passed through the ``UnetInference`` remover, and the
    remover's output is written back into the page.
    """

    def __init__(
        self, detection_weight: Union[str, None] = None, removal_weight: Union[str, None] = None, device: str = "cpu"
    ):
        """Create an object to remove stamps from document images.

        Args:
            detection_weight: Weight path forwarded unchanged to ``StampDetector``.
            removal_weight: Path to the remover weight. When ``None``, the weight is
                downloaded to ``tmp/stamp_remover.pkl`` before loading.
            device: Currently not used. The detector is always created with
                ``device="cpu"`` and the remover receives no device argument.

        Raises:
            FileNotFoundError: If the remover weight cannot be loaded. The original
                loading error is chained as ``__cause__``.
        """
        # NOTE: only CPU inference is supported for now, so ``device`` is not forwarded.

        if removal_weight is None:
            os.makedirs("tmp/", exist_ok=True)
            removal_weight = os.path.join("tmp", "stamp_remover.pkl")

            print("Downloading stamp remover weight from google drive")
            download_weight(REMOVER_WEIGHT_ID, output=removal_weight)
            print(f"Finished downloading. Weight is saved at {removal_weight}")

        try:
            self.remover = UnetInference(removal_weight)  # type: ignore
        except Exception as e:
            print(e)
            print("There is something wrong when loading remover weight")
            # Trailing spaces matter: these adjacent literals are concatenated.
            print(
                "Please make sure you provide the correct path to the weight "
                "or manually download the weight at "
                f"https://drive.google.com/file/d/{REMOVER_WEIGHT_ID}/view?usp=sharing"
            )
            raise FileNotFoundError(f"Could not load stamp remover weight from {removal_weight!r}") from e

        self.detector = StampDetector(detection_weight, device="cpu")
        # Extra pixels kept around every detected box before cropping.
        self.padding = 3

    def __call__(self, image_list: Union[List[npt.NDArray], npt.NDArray], batch_size: int = 16) -> List[npt.NDArray]:
        """Detect and remove stamps from document images.

        Args:
            image_list: Input images, either as a list of arrays or as one stacked
                array indexed along its first axis. Only the first image's shape is
                validated with ``check_image_shape``.
            batch_size: Number of images per detection batch. Defaults to 16.

        Returns:
            The images with stamps removed, in input order. Empty input yields ``[]``.
            NOTE: pages are edited in place, so the returned arrays share memory with
            ``image_list``.

        Raises:
            TypeError: If ``image_list`` is neither a list nor an ``np.ndarray``.
        """
        if not isinstance(image_list, (np.ndarray, list)):
            raise TypeError("Invalid Type: Input must be of type list or np.ndarray")

        if len(image_list) == 0:
            return []
        check_image_shape(image_list[0])
        return self.__batch_removing(image_list, batch_size)  # type:ignore

    @staticmethod
    def _padded_window(box, padding: int, height: int, width: int):
        """Return ``(row_slice, col_slice)`` for ``box`` grown by ``padding`` and clipped to the page.

        ``box[:4]`` is read as ``(x_min, y_min, x_max, y_max)``. The coordinates are
        coerced to Python ``int`` because float values cannot be used as slice
        indices. The coercion also prevents unsigned numpy integers from wrapping
        around when ``padding`` is subtracted.
        """
        x_min, y_min, x_max, y_max = (int(v) for v in box[:4])
        rows = slice(max(y_min - padding, 0), min(y_max + padding, height))
        cols = slice(max(x_min - padding, 0), min(x_max + padding, width))
        return rows, cols

    def __batch_removing(self, image_list, batch_size=16):  # type: ignore
        """Run detection in batches, then erase every detected stamp page by page."""
        shapes = {page.shape for page in image_list}
        images_batch, indices = create_batch(image_list, shapes, batch_size)

        detection_predictions = []
        for batch in images_batch:
            if len(batch):  # empty batches are never sent to the detector
                detection_predictions.extend(self.detector(batch))

        # Predictions come out in batch order. Sorting by ``indices`` is what puts
        # them back in input order. This assumes ``indices`` is a flat list giving
        # each batched image's position in ``image_list`` -- TODO confirm against
        # create_batch.
        ordered = sorted(zip(detection_predictions, indices), key=lambda pair: pair[1])

        new_pages = []
        for idx, (page_boxes, _) in enumerate(ordered):
            page_img = image_list[idx]
            h, w, _channels = page_img.shape
            for box in page_boxes:
                rows, cols = self._padded_window(box, self.padding, h, w)
                stamp_area = page_img[rows, cols]
                if stamp_area.size == 0:
                    # Degenerate box after clipping: nothing to send to the remover.
                    continue
                cleaned = self.remover([stamp_area])  # type:ignore
                # Written back in place, so the caller's image is modified too.
                page_img[rows, cols, :] = cleaned[0]
            new_pages.append(page_img)
        return new_pages