|
import torch |
|
import cv2 |
|
import numpy as np |
|
from PIL import Image, ImageDraw, ImageFont |
|
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification |
|
from utils import OCR, unnormalize_box |
|
|
|
|
|
# Hugging Face checkpoint shared by the tokenizer, the processor and the model.
_CHECKPOINT = "mp-02/layoutlmv3-base-sroie"

# apply_ocr=False is passed at load time: prediction() feeds in the words and
# boxes produced by the project's own OCR helper.
tokenizer = LayoutLMv3TokenizerFast.from_pretrained(_CHECKPOINT, apply_ocr=False)
processor = LayoutLMv3Processor.from_pretrained(_CHECKPOINT, apply_ocr=False)
model = LayoutLMv3ForTokenClassification.from_pretrained(_CHECKPOINT)

# Label lookup tables exactly as stored on the model configuration.
id2label, label2id = model.config.id2label, model.config.label2id

# Run on the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)
|
|
|
|
|
def blur(image, boxes, kernel_size=201):
    """Return a copy of *image* with each region in *boxes* Gaussian-blurred.

    Args:
        image: Anything ``np.array`` can convert to a pixel array. The result
            is rebuilt with mode ``'RGB'``, so the input is expected to be an
            RGB image (assumption carried over from the original code).
        boxes: Iterable of ``(x0, y0, x1, y1)`` pixel boxes. Boxes are clamped
            to the image; boxes with no area inside the image are skipped.
        kernel_size: Side length of the square Gaussian kernel handed to
            ``cv2.GaussianBlur``. OpenCV requires it to be positive and odd.

    Returns:
        A new ``PIL.Image.Image`` in mode ``'RGB'``.
    """
    pixels = np.array(image)
    img_height, img_width = pixels.shape[:2]

    for box in boxes:
        # Same integer rounding as before: truncate the origin and the extent
        # separately, then derive the far corner from them.
        left = int(box[0])
        top = int(box[1])
        right = left + int(box[2] - box[0])
        bottom = top + int(box[3] - box[1])

        # Clamp to the image so negative coordinates cannot wrap around via
        # Python's negative indexing and oversized boxes cannot overrun.
        left, top = max(left, 0), max(top, 0)
        right, bottom = min(right, img_width), min(bottom, img_height)

        # An empty region would make cv2.GaussianBlur raise; nothing to blur.
        if right <= left or bottom <= top:
            continue

        region = pixels[top:bottom, left:right]
        pixels[top:bottom, left:right] = cv2.GaussianBlur(
            region, (kernel_size, kernel_size), 0
        )

    return Image.fromarray(pixels, 'RGB')
|
|
|
|
|
def prediction(image):
    """Extract labelled fields from a document image and blur their regions.

    The image is passed through ``OCR``, the resulting words and boxes are
    encoded with the module-level ``processor`` and classified token by token
    with the module-level ``model``. Word-level predictions below 0.5
    confidence are reset to ``"O"``.

    Args:
        image: Image object exposing ``.size`` as ``(width, height)`` and
            accepted by ``OCR``, the processor and ``blur``.

    Returns:
        tuple: ``(fields, blurred_image)``. ``fields`` maps each predicted
        label (with its two-character prefix removed) to the matching words
        joined by ``", "``; the ``"O"`` and ``"TOTAL"`` entries are dropped.
        ``blurred_image`` is the result of ``blur`` applied to every box whose
        label is neither ``'O'`` nor ``'S-TOTAL'``.
    """
    boxes, words = OCR(image)
    encoding = processor(
        image,
        words,
        boxes=boxes,
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
    )
    # The offset mapping is only used for bookkeeping below; the model does
    # not take it as an input.
    offset_mapping = encoding.pop('offset_mapping')

    for key, tensor in encoding.items():
        encoding[key] = tensor.to(device)

    # Pure inference: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(**encoding)

    # A single image was encoded, so drop the leading batch dimension.
    logits = outputs.logits.squeeze(0)
    predictions = logits.argmax(-1).tolist()
    confidence_scores = torch.softmax(logits, dim=-1).max(-1).values.tolist()
    token_boxes = encoding.bbox.squeeze(0).tolist()
    inp_ids = encoding.input_ids.squeeze(0).tolist()
    inp_words = [tokenizer.decode(token_id) for token_id in inp_ids]

    width, height = image.size

    # A token whose character offset does not start at 0 continues the
    # previous word rather than starting a new one.
    is_subword = (offset_mapping.squeeze(0)[:, 0] != 0).tolist()
    word_starts = [idx for idx, sub in enumerate(is_subword) if not sub]

    true_predictions = [id2label[predictions[idx]] for idx in word_starts]
    true_boxes = [
        unnormalize_box(token_boxes[idx], width, height) for idx in word_starts
    ]
    true_confidence_scores = [confidence_scores[idx] for idx in word_starts]

    # Glue continuation pieces back onto the word they belong to.
    true_words = []
    for piece, sub in zip(inp_words, is_subword):
        if sub:
            true_words[-1] += piece
        else:
            true_words.append(piece)

    # Drop the first and last entries (presumably the sequence start/end
    # special tokens added by the tokenizer - confirm against the tokenizer).
    true_predictions = true_predictions[1:-1]
    true_boxes = true_boxes[1:-1]
    true_words = true_words[1:-1]
    true_confidence_scores = true_confidence_scores[1:-1]

    # Low-confidence predictions are treated as "no entity".
    true_predictions = [
        "O" if score < 0.5 else label
        for label, score in zip(true_predictions, true_confidence_scores)
    ]

    # Group the words by label name, without the two-character tag prefix.
    fields = {}
    for label, word in zip(true_predictions, true_words):
        key = label if label == "O" else label[2:]
        if key in fields:
            fields[key] = fields[key] + ", " + word
        else:
            fields[key] = word
    fields = {key: text.strip() for key, text in fields.items()}

    fields.pop("O", None)
    fields.pop("TOTAL", None)

    # NOTE(review): only the exact label 'S-TOTAL' is exempt from blurring,
    # while every '*-TOTAL' label is removed from `fields` above - confirm
    # the model's label set makes these two rules equivalent.
    blur_boxes = [
        box
        for label, box in zip(true_predictions, true_boxes)
        if label != 'O' and label != 'S-TOTAL'
    ]

    image = blur(image, blur_boxes)

    return fields, image
|
|
|
|