File size: 3,381 Bytes
9690de1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import io
import pandas as pd
import plotly.express as px
import streamlit as st

import torch
import torch.nn.functional as F
from easyocr import Reader
from PIL import Image
from torch.utils.data import Dataset, DataLoader

from transformers import (
    LayoutLMv3FeatureExtractor, 
    LayoutLMv3TokenizerFast, 
    LayoutLMv3Processor, 
    LayoutLMv3ForSequenceClassification
    )

# Torch device string: the first CUDA GPU when one is available, otherwise CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# Checkpoint the tokenizer is loaded from in create_processor().
MICROSOFT_MODEL_NAME = "microsoft/layoutlmv3-base"
# Checkpoint the sequence-classification model is loaded from in create_model().
MODEL_NAME = "curiousily/layoutlmv3-financial-document-classification"

def creat_bounding_box(bbox_data, width_scale: float, height_scale: float):
    """Convert a polygon of corner points into a scaled axis-aligned box.

    Args:
        bbox_data: Iterable of ``(x, y)`` corner points.
        width_scale: Factor applied to the horizontal coordinates.
        height_scale: Factor applied to the vertical coordinates.

    Returns:
        ``[left, top, right, bottom]`` where each value is the scaled
        extreme of the input points, truncated with ``int()``.

    Raises:
        ValueError: If ``bbox_data`` contains no points.
    """
    # Materialize once so generators and other one-shot iterables work too.
    points = list(bbox_data)
    if not points:
        raise ValueError("bbox_data must contain at least one (x, y) point")

    xs = [x for x, _ in points]
    ys = [y for _, y in points]

    # The extremes are computed once, after all points have been collected.
    return [
        int(min(xs) * width_scale),
        int(min(ys) * height_scale),
        int(max(xs) * width_scale),
        int(max(ys) * height_scale),
    ]

@st.experimental_singleton
def create_ocr_reader():
    """Build the EasyOCR ``Reader`` configured for English text.

    NOTE(review): the ``st.experimental_singleton`` decorator presumably keeps
    one shared instance across Streamlit reruns; newer Streamlit releases
    offer ``st.cache_resource`` for this — confirm against the pinned version.
    """
    languages = ["en"]
    return Reader(languages)

@st.experimental_singleton
def create_processor():
    """Assemble the LayoutLMv3 processor used to encode a document.

    The feature extractor is created with ``apply_ocr=False`` because words
    and boxes are supplied by the caller; the tokenizer is loaded from
    ``MICROSOFT_MODEL_NAME``.
    """
    return LayoutLMv3Processor(
        LayoutLMv3FeatureExtractor(apply_ocr=False),
        LayoutLMv3TokenizerFast.from_pretrained(MICROSOFT_MODEL_NAME),
    )

@st.experimental_singleton
def create_model():
    """Load the classifier from ``MODEL_NAME``, in eval mode, on ``DEVICE``."""
    classifier = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
    classifier.eval()
    return classifier.to(DEVICE)

def predict(
    image: Image.Image,
    reader: Reader,
    processor: LayoutLMv3Processor,
    model: LayoutLMv3ForSequenceClassification,
):
    """Classify a document image with OCR words, boxes, and LayoutLMv3.

    Args:
        image: The document as a PIL image (any mode; converted to RGB here).
        reader: EasyOCR reader used to extract words and their polygons.
        processor: LayoutLMv3 processor that builds the model inputs.
        model: Sequence-classification model, expected to live on ``DEVICE``.

    Returns:
        A ``(predicted_class, probabilities)`` tuple: the index of the highest
        logit and the softmax of the logits flattened into a list of floats.
    """
    # Local import keeps this block self-contained; numpy is not imported at
    # the top of the file.
    import numpy as np

    # Normalize the mode once so OCR and the processor both see a 3-channel
    # image (uploads may be RGBA, palette, or grayscale PNGs).
    rgb_image = image.convert("RGB")

    # EasyOCR documents file path, bytes, or numpy array as readtext inputs,
    # so hand it an array rather than the PIL object.
    ocr_result = reader.readtext(np.asarray(rgb_image))

    # Boxes are rescaled from pixel coordinates to a 0-1000 grid.
    width, height = rgb_image.size
    width_scale = 1000 / width
    height_scale = 1000 / height

    words = []
    boxes = []
    for bbox, word, _confidence in ocr_result:
        scaled_box = creat_bounding_box(bbox, width_scale, height_scale)
        words.append(word)
        # OCR polygons can poke slightly outside the image; keep every
        # coordinate inside the 0-1000 grid.
        boxes.append([min(max(coord, 0), 1000) for coord in scaled_box])

    encoding = processor(
        rgb_image,
        words,
        boxes=boxes,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    with torch.inference_mode():
        output = model(
            input_ids=encoding["input_ids"].to(DEVICE),
            attention_mask=encoding["attention_mask"].to(DEVICE),
            bbox=encoding["bbox"].to(DEVICE),
            pixel_values=encoding["pixel_values"].to(DEVICE),
        )

    logits = output.logits
    probabilities = F.softmax(logits, dim=-1).flatten().tolist()
    predicted_class = logits.argmax().item()

    return predicted_class, probabilities


# Obtain the OCR reader, processor, and model from the decorated factory
# functions defined earlier in this file.
reader = create_ocr_reader()
processor = create_processor()
model = create_model()

uploaded_file = st.file_uploader("Upload Document Image", type=["jpg", "png"])
if uploaded_file is not None:
    image = Image.open(io.BytesIO(uploaded_file.getvalue()))
    st.image(image, caption="Your Document")

    predicted_class, probabilities = predict(image, reader, processor, model)
    id2label = model.config.id2label
    predicted_label = id2label[predicted_class]

    st.markdown(f"Predicted document type: **{predicted_label}**")

    # Look labels up by class index so each label lines up with its
    # probability regardless of the mapping's iteration order.
    labels = [id2label[class_id] for class_id in range(len(probabilities))]
    df_predictions = pd.DataFrame(
        {"Document": labels, "Confidence": probabilities}
    )

    fig = px.bar(df_predictions, x="Document", y="Confidence")
    st.plotly_chart(fig, use_container_width=True)