# src/inference_methods.py
# Uploaded by bdpc ("Upload 9 files"), revision 1ceb840
from enum import Enum
from PIL import Image
def unravel_index(index, shape):
    """Convert a flat, row-major ``index`` into one coordinate per dimension of ``shape``.

    Single-index, pure-Python counterpart of ``numpy.unravel_index``. Only ``%`` and
    ``//`` are applied to ``index``, so any value supporting those operators works.
    Out-of-range indices are not rejected: the leading coordinate is simply taken
    modulo ``shape[0]``.

    Args:
        index: flat position in a row-major layout of ``shape``.
        shape: sequence of dimension sizes.

    Returns:
        Tuple of coordinates, ordered like ``shape``.
    """
    coords = []
    remaining = index
    # Peel dimensions off from the fastest-varying (last) one.
    for size in reversed(shape):
        coords.append(remaining % size)
        remaining = remaining // size
    coords.reverse()
    return tuple(coords)
class ExplicitEnum(Enum):
    """Enum base class with a descriptive lookup error and a helper listing all valid values."""

    @classmethod
    def _missing_(cls, value):
        # Enum calls this hook when ``cls(value)`` matches no member; raise an error
        # that tells the user which values are accepted instead of the terse default.
        valid_values = list(cls._value2member_map_)
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {valid_values}"
        )

    @classmethod
    def options(cls):
        """Return every valid member value as a list."""
        return [*cls._value2member_map_]
class InferenceMethod(ExplicitEnum):
    """All the implemented inference methods for multi-page documents.

    A method either selects what to feed the model from a document's pages
    ("sample" / "sample-grid" scope, see ``get_page_scope``) or keeps every page
    ("iter" scope) so that per-page logits can be combined with
    ``apply_decision_strategy``.
    """

    FIRST = "first"  # check on data setup
    SECOND = "second"  # robustness for multipage categories
    LAST = "last"  # robustness for multipage categories
    GRID = "grid"  # create a grid (equal spaced/ OCR density based)
    # downscale resolution ; might be fine for classification
    MAX_CONFIDENCE = "max_confidence"  # page with highest confidence overall
    SOFT_VOTING = "soft_voting"  # sum conf/N -> logits/softmax
    HARD_VOTING = "hard_voting"  # count votes

    @property
    def scope(self):
        """How this method consumes a document.

        Returns:
            "sample" for FIRST/SECOND/LAST (a single page), "sample-grid" for GRID
            (a single composite image built from all pages), "iter" for the rest.
        """
        if self in (InferenceMethod.FIRST, InferenceMethod.SECOND, InferenceMethod.LAST):
            return "sample"
        if self == InferenceMethod.GRID:
            return "sample-grid"  # single image yet transformation required
        return "iter"

    def get_page_scope(self, pages):
        """Select the model input(s) for this method from a document's ``pages``.

        Args:
            pages: sequence of page images (presumably PIL images -- confirm with callers).

        Returns:
            ``pages`` unchanged for "iter" methods; ``pages[0]`` for FIRST;
            ``pages[1]`` for SECOND (``pages[0]`` if there is only one page);
            ``pages[-1]`` for LAST; a grid image built by ``equal_image_grid`` for GRID,
            falling back to ``pages[-1]`` if the grid cannot be built.
        """
        if self.scope == "iter":
            return pages
        if self == InferenceMethod.GRID:
            try:
                return equal_image_grid(pages)
            except Exception:
                # Deliberate best-effort fallback: any failure while building the grid
                # uses the last page instead (last, so as not to positively bias).
                return pages[-1]
        if self == InferenceMethod.FIRST:
            return pages[0]
        if self == InferenceMethod.SECOND:
            if len(pages) > 1:
                return pages[1]
            return pages[0]  # backoff for single-page documents
        if self == InferenceMethod.LAST:
            return pages[-1]

    def apply_decision_strategy(self, page_logits):
        """Aggregate per-page logits into a single predicted class index.

        Args:
            page_logits: array/tensor of shape [NUM_PAGES x CLASSES] exposing
                ``argmax``, ``mean``, ``reshape``, ``tolist`` and ``shape``
                (numpy arrays and torch tensors both do).

        Returns:
            MAX_CONFIDENCE: class of the single highest logit over all pages.
            HARD_VOTING: class predicted by the most pages (ties -> lowest class index).
            SOFT_VOTING: argmax of the logits averaged over pages.
            None for methods that have no decision strategy.
        """
        if self == InferenceMethod.MAX_CONFIDENCE:
            index = page_logits.argmax()  # flat index over pages x classes
            indices = unravel_index(index, page_logits.shape)
            print(f"The page which is max confident: {indices[0]}")
            return indices[-1]
        if self == InferenceMethod.HARD_VOTING:
            from collections import Counter

            # One vote (the predicted class) per page; reshape keeps this 1-D even
            # for a single page's logits.
            votes = page_logits.argmax(-1).reshape(-1)
            vote_list = votes.tolist()
            tally = Counter(vote_list)
            top_count = max(tally.values())
            winner = min(cls for cls, count in tally.items() if count == top_count)
            # Index back into ``votes`` so the result has the same type as the other
            # branches (0-d tensor / numpy scalar), not a plain Python int.
            return votes[vote_list.index(winner)]
        if self == InferenceMethod.SOFT_VOTING:
            return page_logits.mean(0).argmax(-1)
def equal_image_grid(images, max_cols=6):
    """Tile ``images`` row-major into one RGB grid image, as close to square as possible.

    Images with a zero width or height are dropped first. The remaining ones are
    resized (bicubic, aspect ratio preserved) to the smallest width among them; every
    cell is as tall as the tallest resized image, so shorter images leave black
    padding below them.

    Args:
        images: sequence of PIL images (presumably the pages of one document).
        max_cols: maximum number of grid columns; extra images add rows instead.

    Returns:
        A new RGB ``PIL.Image`` containing all the images.

    Raises:
        ValueError: if ``max_cols`` < 1 or no image has a positive width and height.
    """
    if max_cols < 1:
        raise ValueError(f"max_cols must be >= 1, got {max_cols}")

    def compute_grid(n):
        """Return (rows, cols) with rows * cols >= n and cols <= max_cols."""
        side = int(n**0.5)
        cols = min(side, max_cols)
        # Prefer one extra column before adding rows, but never exceed max_cols.
        if side * cols < n and cols < max_cols:
            cols += 1
        rows = max(side, -(-n // cols))  # ceil(n / cols), never fewer than `side`
        return rows, cols

    # Filter before sizing the grid so dropped pages do not leave empty cells.
    images = [im for im in images if (im.height > 0) and (im.width > 0)]  # could be NA
    if not images:
        raise ValueError("equal_image_grid needs at least one image with positive width and height")
    rows, cols = compute_grid(len(images))

    # Rescale to the minimum width (height padding); clamp the height to >= 1 because
    # PIL refuses to resize to a zero-sized image (possible for very wide pages).
    min_width = min(im.width for im in images)
    images = [
        im.resize(
            (min_width, max(1, int(im.height * min_width / im.width))),
            resample=Image.BICUBIC,
        )
        for im in images
    ]
    cell_w = max(im.width for im in images)
    cell_h = max(im.height for im in images)

    grid = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for i, im in enumerate(images):
        row, col = divmod(i, cols)
        grid.paste(im, box=(col * cell_w, row * cell_h))
    return grid