from copy import deepcopy
from typing import Dict, Union, Optional, List

import numpy as np
from PIL import Image
from transformers import DonutImageProcessor
from transformers.image_processing_utils import BatchFeature
from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, \
    valid_images
from transformers.utils import TensorType

from surya.settings import settings


def load_processor(checkpoint=settings.ORDER_MODEL_CHECKPOINT):
    processor = OrderImageProcessor.from_pretrained(checkpoint)
    processor.size = settings.ORDER_IMAGE_SIZE
    box_size = 1024
    max_tokens = 256
    # Special token ids sit past the ranges used for box coordinates (0..box_size)
    # and decoder tokens (0..max_tokens), so they cannot collide with real values.
    processor.token_sep_id = max_tokens + box_size + 1
    processor.token_pad_id = max_tokens + box_size + 2
    # Reserve one slot for the separator box appended in process_boxes.
    processor.max_boxes = settings.ORDER_MAX_BOXES - 1
    processor.box_size = {"height": box_size, "width": box_size}
    return processor
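
# Derived values, for reference: with box_size = 1024 and max_tokens = 256,
# token_sep_id = 256 + 1024 + 1 = 1281 and token_pad_id = 1282, both outside
# the 0..1024 range that rescaled box coordinates can occupy.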


class OrderImageProcessor(DonutImageProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.patch_size = kwargs.get("patch_size", (4, 4))

    def process_inner(self, images: List[np.ndarray]):
        # HWC -> CHW; rescale/normalize below are told the channel dim comes first.
        images = [img.transpose(2, 0, 1) for img in images]
        assert images[0].shape[0] == 3  # expects RGB (3-channel) input

        images = [img.astype(np.float32) for img in images]

        # Rescale pixel values to [0, 1], then normalize with the configured mean/std.
        images = [
            self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]
        images = [
            self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]

        return images
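
    # process_inner example: if the checkpoint keeps the DonutImageProcessor
    # defaults (rescale_factor=1/255, image_mean=image_std=[0.5, 0.5, 0.5]),
    # a uint8 pixel of 255 maps to (255/255 - 0.5) / 0.5 = 1.0 and 0 to -1.0.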

    def process_boxes(self, boxes):
        padded_boxes = []
        box_masks = []
        box_counts = []
        for b in boxes:
            # Append a separator box so the model can tell where the real boxes end.
            padded_b = deepcopy(b)
            padded_b.append([self.token_sep_id] * 4)
            padded_boxes.append(padded_b)

        # Left-pad every sequence to the longest in the batch; the mask is 0 over
        # padding and 1 over real boxes (separator included).
        max_boxes = max(len(b) for b in padded_boxes)
        for i in range(len(padded_boxes)):
            pad_len = max_boxes - len(padded_boxes[i])
            box_len = len(padded_boxes[i])
            box_mask = [0] * pad_len + [1] * box_len
            padded_box = [[self.token_pad_id] * 4] * pad_len + padded_boxes[i]
            padded_boxes[i] = padded_box
            box_masks.append(box_mask)
            box_counts.append([pad_len, max_boxes])

        return padded_boxes, box_masks, box_counts
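
    # process_boxes worked example, using the ids from load_processor
    # (sep=1281, pad=1282):
    # process_boxes([[[10, 20, 30, 40]], [[1, 2, 3, 4], [5, 6, 7, 8]]]) gives
    #   padded_boxes[0] == [[1282] * 4, [10, 20, 30, 40], [1281] * 4]
    #   box_masks[0]    == [0, 1, 1]
    #   box_counts[0]   == [1, 3]  # [pad_len, max_boxes]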

    def resize_img_and_boxes(self, img, boxes):
        orig_dim = img.size
        new_size = (self.size["width"], self.size["height"])
        # Shrink with aspect ratio preserved, then stretch to the exact model input size.
        img.thumbnail(new_size, Image.Resampling.LANCZOS)
        img = img.resize(new_size, Image.Resampling.LANCZOS)

        img = np.asarray(img, dtype=np.uint8)

        # Rescale box coordinates from the original image dimensions into the
        # box_size coordinate space, then clamp them to its bounds.
        width, height = orig_dim
        box_width, box_height = self.box_size["width"], self.box_size["height"]
        for box in boxes:
            box[0] = box[0] / width * box_width
            box[1] = box[1] / height * box_height
            box[2] = box[2] / width * box_width
            box[3] = box[3] / height * box_height

            if box[0] < 0:
                box[0] = 0
            if box[1] < 0:
                box[1] = 0
            if box[2] > box_width:
                box[2] = box_width
            if box[3] > box_height:
                box[3] = box_height

        return img, boxes
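
    # resize_img_and_boxes example: a box [100, 50, 400, 300] on an 800x600
    # image lands at [128, 85.33, 512, 512] in the 1024x1024 box space
    # (before clamping).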

    def preprocess(
        self,
        images: ImageInput,
        boxes: List[List[List[int]]],
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_thumbnail: bool = None,
        do_align_long_axis: bool = None,
        do_pad: bool = None,
        random_padding: bool = False,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        # Most Donut preprocessing flags are accepted for signature compatibility
        # with DonutImageProcessor but are ignored here.
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        # Resize each image to the model input size and rescale its boxes into
        # the box_size coordinate space.
        new_images = []
        new_boxes = []
        for img, box in zip(images, boxes):
            if len(box) > self.max_boxes:
                raise ValueError(f"Too many boxes, max is {self.max_boxes}")
            img, box = self.resize_img_and_boxes(img, box)
            new_images.append(img)
            new_boxes.append(box)

        images = new_images
        boxes = new_boxes

        # Already numpy arrays after resizing; np.array is a cheap copy here.
        images = [np.array(image) for image in images]

        images = self.process_inner(images)
        boxes, box_mask, box_counts = self.process_boxes(boxes)
        data = {
            "pixel_values": images,
            "input_boxes": boxes,
            "input_boxes_mask": box_mask,
            "input_boxes_counts": box_counts,
        }
        return BatchFeature(data=data, tensor_type=return_tensors)
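

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module API. Assumes the configured
    # ORDER_MODEL_CHECKPOINT is reachable and that "page.png" (a placeholder
    # name) exists on disk.
    processor = load_processor()
    image = Image.open("page.png").convert("RGB")
    page_boxes = [[[10, 10, 200, 50], [10, 60, 200, 110]]]  # one xyxy box list per image
    batch = processor(images=[image], boxes=page_boxes, return_tensors="pt")
    print(batch["pixel_values"].shape)  # (1, 3, height, width)
    print(batch["input_boxes"].shape)   # (1, num_boxes + 1, 4) with the separator box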