ERNIE-Layout_Pytorch
Model type: ERNIE-Layout
Repository: an unofficial ERNIE-Layout implementation in PyTorch (source code)
Converted from: PaddlePaddle/ernie-layoutx-base-uncased
The ERNIE-Layout model was originally released by PaddleNLP. To make it easier for PyTorch users to work with, the model has been converted into PyTorch format with the tools/convert2torch.py script. Feel free to make any changes you need. For more details and use cases, please check the repo.
A Quick Example
import torch
from PIL import Image
import torch.nn.functional as F
from networks import ErnieLayoutConfig, ErnieLayoutForQuestionAnswering, \
    ErnieLayoutProcessor, ErnieLayoutTokenizerFast
from transformers.models.layoutlmv3 import LayoutLMv3ImageProcessor
pretrain_torch_model_or_path = "Norm/ERNIE-Layout-Pytorch"
doc_imag_path = "./dummy_input.jpeg"
context = ['This is an example sequence', 'All ocr boxes are inserted into this list']
layout = [[381, 91, 505, 115], [738, 96, 804, 122]]  # one box per context entry; all boxes must be normalized to the 0-1000 range
pil_image = Image.open(doc_imag_path).convert("RGB")
# initialize tokenizer
tokenizer = ErnieLayoutTokenizerFast.from_pretrained(pretrained_model_name_or_path=pretrain_torch_model_or_path)
# initialize feature extractor
feature_extractor = LayoutLMv3ImageProcessor(apply_ocr=False)
processor = ErnieLayoutProcessor(image_processor=feature_extractor, tokenizer=tokenizer)
# Tokenize context & questions
question = "what is it?"
encoding = processor(pil_image, question, context, boxes=layout, return_tensors="pt")
# dummy answer start and end token indices
start_positions = torch.tensor([6])
end_positions = torch.tensor([12])
# initialize config
config = ErnieLayoutConfig.from_pretrained(pretrained_model_name_or_path=pretrain_torch_model_or_path)
config.num_classes = 2 # start and end
# initialize ERNIE for VQA
model = ErnieLayoutForQuestionAnswering.from_pretrained(
    pretrained_model_name_or_path=pretrain_torch_model_or_path,
    config=config,
)
output = model(**encoding, start_positions=start_positions, end_positions=end_positions)
# decode output
start_max = torch.argmax(F.softmax(output.start_logits, dim=-1))
end_max = torch.argmax(F.softmax(output.end_logits, dim=-1)) + 1  # add one because Python slicing excludes the end index
answer = tokenizer.decode(encoding.input_ids[0][start_max: end_max])
print(answer)
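The example above expects every box in `layout` to already be on the 0-1000 scale. If your OCR engine returns pixel coordinates, a small helper along these lines may be useful. This is a minimal sketch, not part of the repo: `normalize_box` and the sample `pixel_boxes` are hypothetical names, and it assumes each box is given as `[x0, y0, x1, y1]` in pixels.

```python
from typing import List, Sequence


def normalize_box(box: Sequence[float], width: int, height: int) -> List[int]:
    """Scale a pixel-space [x0, y0, x1, y1] box to the 0-1000 range.

    Hypothetical helper for illustration; not part of the repository.
    """
    if width <= 0 or height <= 0:
        raise ValueError("width and height must be positive")
    x0, y0, x1, y1 = box
    scaled = [
        int(1000 * x0 / width),
        int(1000 * y0 / height),
        int(1000 * x1 / width),
        int(1000 * y1 / height),
    ]
    # Clamp to guard against OCR boxes that slightly overshoot the page.
    return [min(1000, max(0, v)) for v in scaled]


# Made-up OCR output in pixels for a 1000 x 800 image.
pixel_boxes = [[100, 40, 300, 80], [500, 400, 1000, 800]]
layout = [normalize_box(b, width=1000, height=800) for b in pixel_boxes]
print(layout)  # [[100, 50, 300, 100], [500, 500, 1000, 1000]]
```

With a PIL image, the page size can be read from `pil_image.size`, which returns `(width, height)`.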