|
from enum import Enum |
|
from PIL import Image |
|
|
|
|
|
def unravel_index(index, shape):
    """Translate a flat row-major index into per-dimension coordinates.

    Args:
        index: Flat position; anything supporting ``%`` and ``//`` with the
            entries of ``shape``.
        shape: Sequence of dimension sizes, outermost dimension first.

    Returns:
        A tuple with one coordinate per entry of ``shape``, in the same order.
    """
    coords = [None] * len(shape)
    remaining = index
    # Peel coordinates off from the innermost (last) dimension outwards.
    for axis in range(len(shape) - 1, -1, -1):
        size = shape[axis]
        coords[axis] = remaining % size
        remaining = remaining // size
    return tuple(coords)
|
|
|
|
|
class ExplicitEnum(Enum):
    """
    Enum that reports every accepted value when a lookup fails and can list its own options
    """

    @classmethod
    def options(cls):
        """Return the keys of the value-to-member map as a list."""
        return [*cls._value2member_map_]

    @classmethod
    def _missing_(cls, value):
        """Reject an unknown ``value`` with a ValueError that names the valid choices."""
        message = (
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )
        raise ValueError(message)
|
|
|
|
|
class InferenceMethod(ExplicitEnum):
    """All the implemented inference methods.

    Each member names a strategy for turning a multi-page input into one
    prediction: pick a single page, tile every page into one grid image, or
    keep every page and aggregate the per-page logits afterwards.
    """

    # Strategies that select exactly one page.
    FIRST = "first"
    SECOND = "second"
    LAST = "last"

    # Strategy that merges all pages into one grid image.
    GRID = "grid"

    # Strategies that keep every page and combine the per-page logits.
    MAX_CONFIDENCE = "max_confidence"
    SOFT_VOTING = "soft_voting"
    HARD_VOTING = "hard_voting"

    @property
    def scope(self):
        """Return the granularity this method works at.

        Returns:
            ``"sample"`` for FIRST/SECOND/LAST, ``"sample-grid"`` for GRID and
            ``"iter"`` for every other member.
        """
        if self in (InferenceMethod.FIRST, InferenceMethod.SECOND, InferenceMethod.LAST):
            return "sample"
        if self is InferenceMethod.GRID:
            return "sample-grid"
        return "iter"

    def get_page_scope(self, pages):
        """Select the part of ``pages`` that this method should be run on.

        Args:
            pages: Indexable, sized sequence of pages.

        Returns:
            ``pages`` itself for ``"iter"`` methods; the grid image for GRID
            (or the last page when building the grid fails); otherwise the
            single selected page. SECOND falls back to the first page when
            there is only one.

        Raises:
            IndexError: if a single page is requested from an empty ``pages``.
        """
        if self.scope == "iter":
            return pages
        if self is InferenceMethod.GRID:
            try:
                return equal_image_grid(pages)
            except Exception:
                # Deliberately broad best-effort: any failure while building
                # the grid degrades to using the last page instead.
                return pages[-1]
        if self is InferenceMethod.FIRST:
            return pages[0]
        if self is InferenceMethod.SECOND:
            return pages[1] if len(pages) > 1 else pages[0]
        if self is InferenceMethod.LAST:
            return pages[-1]
        return None

    def apply_decision_strategy(self, page_logits):
        """
        Reduce per-page logits to a single predicted class index.

        page logits is of shape [NUM_PAGES x CLASSES]

        Args:
            page_logits: Array-like exposing ``argmax``, ``mean`` and ``shape``
                (``tolist`` is additionally needed for HARD_VOTING).

        Returns:
            MAX_CONFIDENCE: the class of the single highest logit over all pages.
            HARD_VOTING: the class predicted by the most pages; ties go to the
                class voted for by the earliest page.
            SOFT_VOTING: the argmax of the logits averaged over pages.
            ``None`` for members that are not aggregation strategies.
        """
        if self is InferenceMethod.MAX_CONFIDENCE:
            index = page_logits.argmax()
            indices = unravel_index(index, page_logits.shape)
            print(f"The page which is max confident: {indices[0]}")
            return indices[-1]
        if self is InferenceMethod.HARD_VOTING:
            votes = page_logits.argmax(-1)
            ballots = votes.tolist()
            if not isinstance(ballots, list):
                # A single page (1-D logits) casts the only vote.
                return votes
            # Majority vote. Taking the maximum of the per-page predictions
            # would return the largest class index seen, not the most frequent.
            # max() keeps the first maximal element, so ties resolve to the
            # earliest page's prediction.
            winner = max(ballots, key=ballots.count)
            # Index back into ``votes`` so the result has the same element
            # type as the other strategies' results.
            return votes[ballots.index(winner)]
        if self is InferenceMethod.SOFT_VOTING:
            return page_logits.mean(0).argmax(-1)
        return None
|
|
|
|
|
def equal_image_grid(images):
    """Tile ``images`` into a single RGB grid image.

    Images with a zero width or height are dropped. The remaining ones are
    rescaled with bicubic resampling to the smallest width among them (aspect
    ratio preserved) and pasted left to right, top to bottom, into cells whose
    height is that of the tallest rescaled image.

    Args:
        images: Iterable of PIL images.

    Returns:
        A new RGB ``PIL.Image.Image`` of size
        ``(cols * cell_width, rows * cell_height)``.

    Raises:
        ValueError: if no image with a non-zero width and height is given.
    """

    def compute_grid(n, max_cols=6):
        """Return ``(rows, cols)`` with ``rows * cols >= n``, grown from a
        square of side ``floor(sqrt(n))``."""
        side = int(n**0.5)
        rows, cols = side, min(side, max_cols)
        if rows * cols < n:
            # NOTE: this widening step is applied even when cols already
            # equals max_cols, so the result can have max_cols + 1 columns.
            # Kept as-is so existing grid layouts do not change.
            cols += 1
        while rows * cols < n:
            rows += 1
        return rows, cols

    # Filter BEFORE sizing the grid so dropped images do not leave surplus
    # empty cells (or a whole empty row) behind.
    images = [im for im in images if (im.height > 0) and (im.width > 0)]
    if not images:
        raise ValueError("equal_image_grid needs at least one image with non-zero width and height")

    rows, cols = compute_grid(len(images))

    min_width = min(im.width for im in images)
    images = [
        # max(1, ...) keeps extremely wide images from truncating to a
        # zero-pixel height, which resize cannot produce.
        im.resize((min_width, max(1, int(im.height * min_width / im.width))), resample=Image.BICUBIC)
        for im in images
    ]

    cell_w = max(img.size[0] for img in images)
    cell_h = max(img.size[1] for img in images)

    grid = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for i, img in enumerate(images):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * cell_w, row * cell_h))
    return grid
|
|