File size: 3,770 Bytes
1ceb840
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from enum import Enum
from PIL import Image


def unravel_index(index, shape):
    """Convert a flat index into a tuple of per-axis coordinates.

    The last axis of ``shape`` varies fastest (row-major order).

    Args:
        index: flat position; anything supporting ``%`` and ``//`` with the
            entries of ``shape``.
        shape: sequence of axis sizes.

    Returns:
        A tuple with one coordinate per entry of ``shape``.
    """
    coords = [0] * len(shape)
    # Walk the axes from last to first, peeling off one coordinate per axis.
    for axis in range(len(shape) - 1, -1, -1):
        extent = shape[axis]
        coords[axis] = index % extent
        index = index // extent
    return tuple(coords)


class ExplicitEnum(Enum):
    """
    Enum whose failed value lookups list every accepted value, and which can report its options
    """

    @classmethod
    def options(cls):
        """Return the values registered for this enum as a list."""
        return [*cls._value2member_map_]

    @classmethod
    def _missing_(cls, value):
        """Reject an unknown ``value`` with a ValueError naming the valid choices."""
        valid = list(cls._value2member_map_.keys())
        raise ValueError(f"{value} is not a valid {cls.__name__}, please select one of {valid}")


class InferenceMethod(ExplicitEnum):
    """All the implemented inference methods.

    A member either selects what ``get_page_scope`` returns from a list of
    pages (one page, or one grid image) or determines how
    ``apply_decision_strategy`` reduces per-page logits to a single class.
    """

    FIRST = "first"  # check on data setup
    SECOND = "second"  # robustness for multipage categories
    LAST = "last"  # robustness for multipage categories

    GRID = "grid"  # create a grid (equal spaced/ OCR density based)
    # downscale resolution ; might be fine for classification

    MAX_CONFIDENCE = "max_confidence"  # page with highest confidence overall
    SOFT_VOTING = "soft_voting"  # sum conf/N   -> logits/softmax
    HARD_VOTING = "hard_voting"  # count votes

    @property
    def scope(self):
        """Return how this method consumes the pages.

        Returns:
            ``"sample"`` for FIRST/SECOND/LAST, ``"sample-grid"`` for GRID,
            and ``"iter"`` for every other member.
        """
        if self in (InferenceMethod.FIRST, InferenceMethod.SECOND, InferenceMethod.LAST):
            return "sample"
        if self is InferenceMethod.GRID:
            return "sample-grid"  # single image yet transformation required
        return "iter"

    def get_page_scope(self, pages):
        """Select the part of ``pages`` this method operates on.

        Args:
            pages: indexable sequence of pages.

        Returns:
            ``pages`` unchanged for "iter" methods, the result of
            ``equal_image_grid(pages)`` for GRID (the last page if building
            the grid fails), or the single selected page otherwise.
        """
        if self.scope == "iter":
            return pages
        if self is InferenceMethod.GRID:
            try:
                return equal_image_grid(pages)
            except Exception:
                # Deliberate best-effort: any failure while building the grid
                # degrades to a single page.
                return pages[-1]  # last to not positively bias
        if self is InferenceMethod.FIRST:
            return pages[0]
        if self is InferenceMethod.SECOND:
            # backoff to the first page for single-page inputs
            return pages[1] if len(pages) > 1 else pages[0]
        if self is InferenceMethod.LAST:
            return pages[-1]
        return None

    def apply_decision_strategy(self, page_logits):
        """
        page logits is of shape [NUM_PAGES x CLASSES]

        Returns the predicted class index for MAX_CONFIDENCE (class of the
        single highest logit over all pages), HARD_VOTING (class predicted by
        the most pages; ties go to the class that appears first in page
        order) and SOFT_VOTING (argmax of the per-class mean over pages).
        Returns None for every other member.
        """
        if self is InferenceMethod.MAX_CONFIDENCE:
            index = page_logits.argmax()  # flat index of the overall maximum
            indices = unravel_index(index, page_logits.shape)
            print(f"The page which is max confident: {indices[0]}")
            return indices[-1]
        if self is InferenceMethod.HARD_VOTING:
            from collections import Counter

            votes = page_logits.argmax(-1)  # one predicted class per page
            flat_votes = votes.reshape(-1)
            ballots = flat_votes.tolist()
            if not ballots:
                # No pages: defer to the array library so it raises its own error.
                return votes.max()
            winner = Counter(ballots).most_common(1)[0][0]
            # Index back into the votes so the result keeps the array
            # library's own scalar type instead of becoming a Python int.
            return flat_votes[ballots.index(winner)]
        if self is InferenceMethod.SOFT_VOTING:
            return page_logits.mean(0).argmax(-1)
        return None


def equal_image_grid(images):
    """Tile ``images`` into a single RGB grid image.

    Images with zero width or height are dropped. The remaining ones are
    rescaled to the narrowest width (aspect ratio kept) and pasted row by row
    into equally sized cells; each cell is as tall as the tallest rescaled
    image, so shorter images leave unpainted space below them.

    Args:
        images: iterable of PIL images.

    Returns:
        A new RGB ``PIL.Image.Image`` containing all valid images.

    Raises:
        ValueError: if no image with non-zero size is given.
    """

    def compute_grid(n, max_cols=6):
        """Return (rows, cols) with rows * cols >= n, near-square, cols <= max_cols."""
        side = int(n**0.5)
        rows, cols = side, min(side, max_cols)
        # Widen by one column first, but never beyond the column cap ...
        if rows * cols < n and cols < max_cols:
            cols += 1
        # ... then add rows until every image has a cell.
        while rows * cols < n:
            rows += 1
        return rows, cols

    # Filter before sizing the grid so dropped images do not reserve cells.
    images = [im for im in images if (im.height > 0) and (im.width > 0)]  # could be NA
    if not images:
        raise ValueError("equal_image_grid requires at least one image with non-zero size")

    rows, cols = compute_grid(len(images))

    # rescaling to min width [height padding]
    cell_w = min(im.width for im in images)
    images = [
        # Clamp to 1px so extremely wide images do not collapse to zero height.
        im.resize((cell_w, max(1, int(im.height * cell_w / im.width))), resample=Image.BICUBIC)
        for im in images
    ]
    cell_h = max(im.height for im in images)

    grid = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for i, img in enumerate(images):
        grid.paste(img, box=(i % cols * cell_w, i // cols * cell_h))
    return grid