# DonutMathHWP / utils.py
import os
import re
import torch
import requests
import numpy as np
import PIL.Image
import PIL.ImageOps
from PIL import Image
from typing import Union
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig


def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
"""
Loads `image` to a PIL Image.
Args:
image (`str` or `PIL.Image.Image`):
The image to convert to the PIL Image format.
Returns:
`PIL.Image.Image`:
A PIL Image.
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
image = PIL.Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, PIL.Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image


def aspect_ratio_preserving_resize_and_crop(image, target_width, target_height):
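    """
    Resize `image` so it fits within (target_width, target_height) while preserving its
    aspect ratio, then paste it centered onto a white canvas of exactly that size.
    Images already smaller than the target are only padded, never upscaled; despite the
    name, nothing is cropped.
    """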
width, height = image.size
width_ratio = width / target_width
height_ratio = height / target_height
if width > target_width and height > target_height:
if width_ratio > height_ratio:
new_width = target_width
new_height = int(new_width / (width / height))
else:
new_height = target_height
new_width = int(new_height * (width / height))
elif width > target_width:
new_width = target_width
new_height = int(new_width / (width / height))
elif height > target_height:
new_height = target_height
new_width = int(new_height * (width / height))
else:
new_width, new_height = width, height
resized_image = image.resize((new_width, new_height), Image.LANCZOS)
padded_image = Image.new("RGB", (target_width, target_height), (255, 255, 255))
offset_x = (target_width - new_width) // 2
offset_y = (target_height - new_height) // 2
padded_image.paste(resized_image, (offset_x, offset_y))
return padded_image


class Image2Text:
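    """
    Image-to-text wrapper around a Donut-style VisionEncoderDecoderModel and its DonutProcessor.

    Args:
        model_path (`str`): Hugging Face Hub repo id or local path of the model.
        hf_token (`str`): Hugging Face access token used to load the model and processor.
        device (`str` or `torch.device`): Device the model and inputs are placed on.
        max_length (`int`, defaults to 1024): Maximum number of tokens to generate.
    """
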
def __init__(self, model_path, hf_token, device, max_length=1024):
self.device = device
self.hf_token = hf_token
self.model_path = model_path
self.max_length = max_length
self.model, self.processor = self.load_model(self.model_path)
self.decoder_input_ids = torch.tensor([[self.model.config.decoder_start_token_id]]).to(self.device)

    def load_model(self, model_path):
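        """
        Load the VisionEncoderDecoderConfig, DonutProcessor, and VisionEncoderDecoderModel
        from `model_path`, move the model to the target device, and put it in eval mode.
        """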
config = VisionEncoderDecoderConfig.from_pretrained(model_path, token=self.hf_token)
processor = DonutProcessor.from_pretrained(model_path, token=self.hf_token)
model = VisionEncoderDecoderModel.from_pretrained(model_path, config=config, token=self.hf_token).to(self.device)
model.eval()
return model, processor

    def load_img(self, inputs, width=480, height=480):
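        """
        Load a batch of images (local paths, URLs, or PIL images), resize and pad each to
        `width` x `height`, and return the preprocessed pixel values on the target device.
        """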
images = [load_image(input_) for input_ in inputs]
images = [aspect_ratio_preserving_resize_and_crop(image, target_width=width, target_height=height) for image in images]
imgs = self.processor([image.convert("RGB") for image in images], return_tensors="pt", size=(width, height)).pixel_values
pixel_values = imgs.to(self.device)
return pixel_values

    def generate(self, pixel_values, num_beams):
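        """
        Generate token sequences for every image in the batch, decoding with `num_beams`
        beams from the decoder start token, and return the raw generation output.
        """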
outputs = self.model.generate(
pixel_values,
decoder_input_ids=self.decoder_input_ids.repeat(pixel_values.shape[0], 1),
max_length=self.max_length,
early_stopping=True,
pad_token_id=self.processor.tokenizer.pad_token_id,
eos_token_id=self.processor.tokenizer.eos_token_id,
use_cache=True,
num_beams=num_beams,
bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
return_dict_in_generate=True,
)
return outputs

    def postprocessing(self, outputs):
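        """
        Decode the generated sequences, strip special tokens and the leading task token,
        parse them with `token2json`, and replace `[newline]` markers with real newlines.
        """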
seqs = self.processor.batch_decode(outputs.sequences)
seqs = [seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "") for seq in seqs]
seqs = [re.sub(r"<.*?>", "", seq, count=1).strip() for seq in seqs]
seqs = [self.processor.token2json(seq) for seq in seqs]
contents = []
for seq in seqs:
try:
content = seq['content']
            except KeyError:
content = seq['text_sequence']
contents.append('\n'.join(content.split('[newline]')))
return contents

    def get_text(self, img_path, num_beams=4):
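        """
        End-to-end inference: preprocess the images in `img_path` (a list of paths, URLs,
        or PIL images), generate sequences, and return the post-processed text contents.
        """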
pixel_values = self.load_img(img_path)
outputs = self.generate(pixel_values, num_beams)
contents = self.postprocessing(outputs)
return contents
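

# A minimal usage sketch (not part of the original module). The repo id, token source,
# and image path below are placeholders; this assumes a Donut-style checkpoint that is
# compatible with DonutProcessor and VisionEncoderDecoderModel.
if __name__ == "__main__":
    image2text = Image2Text(
        model_path="your-username/your-donut-checkpoint",        # placeholder repo id
        hf_token=os.environ.get("HF_TOKEN"),                     # placeholder token source
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    # get_text expects a list of image paths, URLs, or PIL images.
    print(image2text.get_text(["sample_page.png"], num_beams=4))  # placeholder image path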