import os
import re

import requests
import torch
import PIL.Image
import PIL.ImageOps
from PIL import Image
from typing import Union

from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
    """
    Loads `image` into a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format. Can be an `http(s)`
            URL, a local file path, or an existing PIL Image.

    Returns:
        `PIL.Image.Image`: A PIL Image in RGB mode with EXIF orientation applied.
    """
    if isinstance(image, str):
        if image.startswith("http://") or image.startswith("https://"):
            # Stream the remote file and let PIL decode it.
            image = PIL.Image.open(requests.get(image, stream=True).raw)
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            raise ValueError(
                f"Incorrect path or URL: URLs must start with `http://` or `https://`, and {image} is not a valid local path."
            )
    elif not isinstance(image, PIL.Image.Image):
        raise ValueError(
            "Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
        )
    # Honour the EXIF orientation tag, then normalise to RGB.
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image
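
# Usage sketch for `load_image` (the URL and path below are placeholders,
# not values used by this Space):
#   img = load_image("https://example.com/receipt.png")   # remote image
#   img = load_image("/tmp/receipt.png")                  # local file
#   img = load_image(PIL.Image.new("RGB", (64, 64)))      # pass-through PIL image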
def aspect_ratio_preserving_resize_and_crop(image, target_width, target_height):
    """Resize `image` to fit inside (target_width, target_height) without
    distorting its aspect ratio, then centre it on a white canvas of exactly
    that size. Despite the name, nothing is cropped: images that already fit
    are only padded."""
    width, height = image.size
    width_ratio = width / target_width
    height_ratio = height / target_height

    if width > target_width and height > target_height:
        # Both dimensions overflow: shrink along the more-overflowing axis.
        if width_ratio > height_ratio:
            new_width = target_width
            new_height = int(new_width / (width / height))
        else:
            new_height = target_height
            new_width = int(new_height * (width / height))
    elif width > target_width:
        # Only the width overflows.
        new_width = target_width
        new_height = int(new_width / (width / height))
    elif height > target_height:
        # Only the height overflows.
        new_height = target_height
        new_width = int(new_height * (width / height))
    else:
        # Already fits: keep the original size and just pad below.
        new_width, new_height = width, height

    resized_image = image.resize((new_width, new_height), Image.LANCZOS)

    # Centre the resized image on a white canvas of the target size.
    padded_image = Image.new("RGB", (target_width, target_height), (255, 255, 255))
    offset_x = (target_width - new_width) // 2
    offset_y = (target_height - new_height) // 2
    padded_image.paste(resized_image, (offset_x, offset_y))
    return padded_image
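
# Quick sanity sketch (hypothetical sizes): an 800x600 input aimed at a 480x480
# canvas scales to 480x360 (width is the more-overflowing axis) and ends up
# centred on white padding.
#   demo = aspect_ratio_preserving_resize_and_crop(Image.new("RGB", (800, 600)), 480, 480)
#   assert demo.size == (480, 480)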
class Image2Text:
    def __init__(self, model_path, hf_token, device, max_length=1024):
        self.device = device
        self.hf_token = hf_token
        self.model_path = model_path
        self.max_length = max_length
        self.model, self.processor = self.load_model(self.model_path)
        # A single decoder start token; `generate` repeats it per batch item.
        self.decoder_input_ids = torch.tensor(
            [[self.model.config.decoder_start_token_id]]
        ).to(self.device)

    def load_model(self, model_path):
        config = VisionEncoderDecoderConfig.from_pretrained(model_path, token=self.hf_token)
        processor = DonutProcessor.from_pretrained(model_path, token=self.hf_token)
        model = VisionEncoderDecoderModel.from_pretrained(
            model_path, config=config, token=self.hf_token
        ).to(self.device)
        model.eval()  # inference only: disable dropout etc.
        return model, processor
    def load_img(self, inputs, width=480, height=480):
        # Images are already RGB after `load_image` and the padding step.
        images = [load_image(input_) for input_ in inputs]
        images = [
            aspect_ratio_preserving_resize_and_crop(image, target_width=width, target_height=height)
            for image in images
        ]
        # `size` is passed as a {"height", "width"} dict, the format the image
        # processor expects; a bare tuple is interpreted in (height, width) order.
        pixel_values = self.processor(
            images, return_tensors="pt", size={"height": height, "width": width}
        ).pixel_values
        return pixel_values.to(self.device)
    def generate(self, pixel_values, num_beams):
        outputs = self.model.generate(
            pixel_values,
            # Repeat the start token so every image in the batch gets a prompt.
            decoder_input_ids=self.decoder_input_ids.repeat(pixel_values.shape[0], 1),
            max_length=self.max_length,
            early_stopping=True,
            pad_token_id=self.processor.tokenizer.pad_token_id,
            eos_token_id=self.processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=num_beams,
            # Forbid the unknown token in the generated output.
            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
        return outputs
    def postprocessing(self, outputs):
        seqs = self.processor.batch_decode(outputs.sequences)
        # Strip special tokens, then the leading task start token (e.g. <s_...>).
        seqs = [
            seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            for seq in seqs
        ]
        seqs = [re.sub(r"<.*?>", "", seq, count=1).strip() for seq in seqs]
        seqs = [self.processor.token2json(seq) for seq in seqs]
        contents = []
        for seq in seqs:
            try:
                content = seq["content"]
            except KeyError:
                # token2json falls back to a flat "text_sequence" field when the
                # output has no structured "content" key.
                content = seq["text_sequence"]
            contents.append("\n".join(content.split("[newline]")))
        return contents
    def get_text(self, inputs, num_beams=4):
        """OCR a batch of images (URLs, local paths, or PIL images) and return
        one decoded string per image."""
        pixel_values = self.load_img(inputs)
        outputs = self.generate(pixel_values, num_beams)
        return self.postprocessing(outputs)
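

# Minimal end-to-end sketch. The checkpoint, token source, and image URL below
# are assumptions for illustration, not values shipped with this Space; the
# class requires a checkpoint whose config sets `decoder_start_token_id`.
if __name__ == "__main__":
    reader = Image2Text(
        model_path="naver-clova-ix/donut-base",  # assumption: a Donut-style checkpoint
        hf_token=os.environ.get("HF_TOKEN"),     # assumption: token kept in an env var
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    texts = reader.get_text(["https://example.com/receipt.png"])  # placeholder URL
    print(texts[0])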