LayoutLMv3_for_recepits2 / inference.py
mp-02's picture
Update inference.py
d6f6a75 verified
raw
history blame
2.7 kB
import torch
import numpy as np
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from PIL import Image, ImageDraw, ImageFont
from utils import OCR, unnormalize_box
import json

# Load the fine-tuned checkpoint once at import time. `apply_ocr=False` is
# passed because words and boxes are supplied by `utils.OCR` in `prediction`.
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-cord-sroie", apply_ocr=False)
processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-cord-sroie", apply_ocr=False)
model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-cord-sroie")

# Class-id <-> label-name mappings stored in the model configuration.
id2label = model.config.id2label
label2id = model.config.label2id

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# NOTE: the former module-level statement
#     labels = processor.tokenizer.convert_ids_to_tokens(predicted_ids)
# was removed: `predicted_ids` does not exist at import time (NameError), and
# predicted class ids map to label names through `id2label`, not through the
# tokenizer vocabulary.
# Build the JSON output in a CORD-like format
def create_json_output(words, labels, boxes):
    """Serialise the labelled words as a CORD-like JSON string.

    The three sequences are walked in lockstep. Every word whose label is
    not "O" becomes an object with the keys "text", "category" and
    "bounding_box"; words labelled "O" are skipped.

    Args:
        words: Recognised words.
        labels: Label name for each word (e.g. "B-PRODUCT", "B-PRICE", "B-TOTAL").
        boxes: Bounding box for each word.

    Returns:
        The list of kept objects dumped as JSON with a 4-space indent.
    """
    entities = [
        {
            "text": token,
            "category": tag,
            "bounding_box": bbox,
        }
        for token, tag, bbox in zip(words, labels, boxes)
        if tag != "O"
    ]
    return json.dumps(entities, indent=4)
def prediction(image):
    """Classify the words of a receipt image and annotate the image.

    The image is OCR'd, the words and boxes are run through the fine-tuned
    LayoutLMv3 model, and one label per word is kept (the prediction of the
    word's first sub-token). Each labelled word is drawn on the image with
    its label and confidence.

    Args:
        image: A PIL image. Non-RGB images are converted to RGB first; RGB
            images are drawn on in place.

    Returns:
        A tuple ``(image, json_result)``: the annotated image and the JSON
        string produced by ``create_json_output`` for the non-"O" words.
    """
    # ImageDraw.Draw(image, "RGBA") rejects any source mode other than RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")

    boxes, words = OCR(image)
    if not words:
        # Nothing was recognised, so there is nothing to classify or draw.
        return image, create_json_output([], [], [])

    # The word list is the processor's `text` argument (second positional);
    # the previous `words=` keyword is not a processor parameter.
    encoding = processor(
        image,
        words,
        boxes=boxes,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
    )

    # word_ids() maps each token to the index of its source word, or None for
    # special/padding tokens (requires a fast tokenizer).
    word_ids = encoding.word_ids(0)
    token_boxes = encoding["bbox"][0].tolist()

    # The model lives on `device`, so its inputs must be moved there too.
    inputs = {name: tensor.to(device) for name, tensor in encoding.items()}
    with torch.no_grad():
        outputs = model(**inputs)

    # Per-token class probabilities for the single image in the batch.
    probabilities = torch.softmax(outputs.logits[0], dim=-1)
    best_scores, best_ids = probabilities.max(dim=-1)
    confidence_scores = best_scores.tolist()
    predicted_ids = best_ids.tolist()

    width, height = image.size
    kept_words, kept_labels, kept_boxes = [], [], []
    true_predictions, true_boxes, true_confidence_scores = [], [], []
    previous_word_index = None
    for token_index, word_index in enumerate(word_ids):
        # Skip special/padding tokens and the trailing sub-tokens of a word,
        # so that each word contributes exactly one prediction.
        if word_index is None or word_index == previous_word_index:
            previous_word_index = word_index
            continue
        previous_word_index = word_index

        label = id2label[predicted_ids[token_index]]
        kept_words.append(words[word_index])
        kept_labels.append(label)
        kept_boxes.append(boxes[word_index])

        true_predictions.append(label)
        # NOTE(review): assumes utils.unnormalize_box(box, width, height) maps
        # the model's 0-1000 box back to pixel coordinates - confirm in utils.
        true_boxes.append(unnormalize_box(token_boxes[token_index], width, height))
        true_confidence_scores.append(confidence_scores[token_index])

    # Build the JSON from the word-level labels obtained above.
    json_result = create_json_output(kept_words, kept_labels, kept_boxes)

    draw = ImageDraw.Draw(image, "RGBA")
    font = ImageFont.load_default()
    for label, box, confidence in zip(true_predictions, true_boxes, true_confidence_scores):
        draw.rectangle(box)
        # `font_size` is dropped: it only applies when no font is passed.
        draw.text((box[0] + 10, box[1] - 10), text=f"{label}, {confidence}", font=font, fill="black")
    return image, json_result