"""Page-selection and vote-aggregation strategies for multi-page inference."""

from collections import Counter
from enum import Enum

try:
    from PIL import Image
except ImportError:  # Pillow is only needed by equal_image_grid (GRID strategy)
    Image = None


def unravel_index(index, shape):
    """Convert a flat (row-major) index into a tuple of per-dimension indices.

    Args:
        index: Flat position; anything supporting ``%`` and ``//`` with an int
            (plain int, numpy integer, 0-dim tensor, ...).
        shape: Sequence with the size of each dimension.

    Returns:
        Tuple with one coordinate per entry of ``shape``.
    """
    coords = []
    # Peel off the fastest-varying (last) dimension first.
    for dim in reversed(shape):
        coords.append(index % dim)
        index = index // dim
    return tuple(reversed(coords))


class ExplicitEnum(Enum):
    """
    Enum with more explicit error message for missing values or getting all options
    """

    @classmethod
    def _missing_(cls, value):
        """Raise a ValueError that lists every valid value for this enum."""
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )

    @classmethod
    def options(cls):
        """Return the list of all valid values, in definition order."""
        return list(cls._value2member_map_.keys())


class InferenceMethod(ExplicitEnum):
    """All the implemented inference methods"""

    FIRST = "first"  # check on data setup
    SECOND = "second"  # robustness for multipage categories
    LAST = "last"  # robustness for multipage categories
    GRID = "grid"  # create a grid (equal spaced/ OCR density based)
    # downscale resolution ; might be fine for classification
    MAX_CONFIDENCE = "max_confidence"  # page with highest confidence overall
    SOFT_VOTING = "soft_voting"  # sum conf/N -> logits/softmax
    HARD_VOTING = "hard_voting"  # count votes

    @property
    def scope(self):
        """Return how this method consumes pages.

        ``"sample"`` picks a single page, ``"sample-grid"`` builds a single
        image out of all pages, ``"iter"`` processes every page.
        """
        if self in (InferenceMethod.FIRST, InferenceMethod.SECOND, InferenceMethod.LAST):
            return "sample"
        if self is InferenceMethod.GRID:
            return "sample-grid"  # single image yet transformation required
        return "iter"

    def get_page_scope(self, pages):
        """Select what should be fed to the model for this method.

        Args:
            pages: Sequence of page images.

        Returns:
            ``pages`` unchanged for ``"iter"`` methods, otherwise a single
            page (or a single grid image for ``GRID``).

        Raises:
            ImportError: for ``GRID`` when Pillow is not installed.
        """
        if self.scope == "iter":
            return pages
        if self is InferenceMethod.GRID:
            try:
                return equal_image_grid(pages)
            except ImportError:
                # A missing dependency is a setup error, not a bad sample:
                # surface it instead of silently degrading to one page.
                raise
            except Exception:
                # Best-effort: any failure while building the grid falls back
                # to a single page.
                return pages[-1]  # last to not positively bias
        if self is InferenceMethod.FIRST:
            return pages[0]
        if self is InferenceMethod.SECOND:
            if len(pages) > 1:
                return pages[1]
            return pages[0]  # backoff
        if self is InferenceMethod.LAST:
            return pages[-1]

    def apply_decision_strategy(self, page_logits):
        """
        page logits is of shape [NUM_PAGES x CLASSES]

        Aggregate per-page scores into one predicted class index.

        Args:
            page_logits: Array-like (e.g. torch tensor or numpy array)
                providing ``argmax``, ``mean``, ``shape`` and, for hard
                voting, ``tolist`` on the result of ``argmax``.

        Returns:
            The predicted class index, as produced by the array library.
            ``None`` for methods that do not aggregate over pages.
        """
        if self is InferenceMethod.MAX_CONFIDENCE:
            index = page_logits.argmax()  # tensor with one number
            indices = unravel_index(index, page_logits.shape)
            print(f"The page which is max confident: \n{indices[0]}")
            return indices[-1]
        if self is InferenceMethod.HARD_VOTING:
            votes = page_logits.argmax(-1)  # one class vote per page
            labels = votes.tolist()
            if not isinstance(labels, list):
                # Input had no page dimension: the single vote wins.
                return votes
            # Majority vote; on ties the class voted by the earliest page wins
            # (Counter.most_common keeps first-encountered order).
            winner, _ = Counter(labels).most_common(1)[0]
            # Index back into ``votes`` so the return type matches the other
            # strategies (array scalar rather than a plain int).
            return votes[labels.index(winner)]
        if self is InferenceMethod.SOFT_VOTING:
            return page_logits.mean(0).argmax(-1)


def _compute_grid(n, max_cols=6):
    """Return ``(rows, cols)`` of a near-square grid holding ``n`` cells.

    The grid never gets wider than ``max_cols``; extra rows are added instead.
    """
    side = int(n ** 0.5)
    rows, cols = side, min(side, max_cols)
    # Prefer widening by one column before adding rows, within the cap.
    if rows * cols < n and cols < max_cols:
        cols += 1
    while rows * cols < n:
        rows += 1
    return rows, cols


def equal_image_grid(images):
    """Paste all images into one RGB grid image with equally sized cells.

    Images with a zero width or height are skipped. The remaining images are
    rescaled to the smallest width (keeping aspect ratio) and laid out row by
    row; shorter images leave the rest of their cell at the default fill.

    Args:
        images: Iterable of PIL images.

    Returns:
        A new ``PIL.Image.Image`` in mode ``"RGB"``.

    Raises:
        ImportError: if Pillow is not installed.
        ValueError: if no image with a non-zero size is given.
    """
    if Image is None:
        raise ImportError("Pillow is required to build an image grid")

    # Filter first so the grid is sized for the images actually pasted.
    images = [im for im in images if (im.height > 0) and (im.width > 0)]  # could be NA
    if not images:
        raise ValueError("equal_image_grid needs at least one image with a non-zero size")

    rows, cols = _compute_grid(len(images))

    # rescaling to min width [height padding]
    min_width = min(im.width for im in images)
    images = [
        im.resize(
            # max(1, ...) guards against a zero height after rounding down.
            (min_width, max(1, int(im.height * min_width / im.width))),
            resample=Image.BICUBIC,
        )
        for im in images
    ]

    cell_w = min_width  # every image now has this width
    cell_h = max(im.size[1] for im in images)
    grid = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for i, im in enumerate(images):
        row, col = divmod(i, cols)
        grid.paste(im, box=(col * cell_w, row * cell_h))
    return grid