import streamlit as st
import warnings
warnings.simplefilter("ignore", UserWarning)

from uuid import uuid4
from laia.scripts.htr.decode_ctc import run as decode
from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs
import sys
from tempfile import NamedTemporaryFile, mkdtemp
from pathlib import Path
from contextlib import redirect_stdout
import re
from PIL import Image
from bidi.algorithm import get_display
import multiprocessing
from ultralytics import YOLO
import cv2
import numpy as np
import pandas as pd
import logging

# Silence verbose PyTorch Lightning logging; only errors are reported
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)

# Load the YOLOv8 segmentation model; 'model.pt' is assumed to be a checkpoint
# fine-tuned for text-line detection (class 0 = text line)
model = YOLO('model.pt')
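
# Working directory for cropped line images, plus recognizer settings:
# line crops are normalized to DEFAULT_HEIGHT pixels, Samaritan script reads
# right-to-left, and NUM_WORKERS is available for parallel data loading.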
images = Path(mkdtemp())
DEFAULT_HEIGHT = 128
TEXT_DIRECTION = "RTL"
NUM_WORKERS = multiprocessing.cpu_count()

# Regex pattern for extracting results
IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"  # line-level confidence score
TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
LINE_PREDICTION = re.compile(rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN} {TEXT_PATTERN}")
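# Each prediction printed by decode() is expected to match the pattern above:
#   "<36-char image uuid> <line confidence> <recognized text>"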

def get_width(image, height=DEFAULT_HEIGHT):
    aspect_ratio = image.width / image.height
    return height * aspect_ratio
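# Example: a 1000x500 crop has aspect ratio 2.0, so at DEFAULT_HEIGHT = 128
# get_width returns 256.0 and the line is resized to 256x128.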

def predict(model_name, input_img):
    model_dir = model_name
    temperature = 2.0
    batch_size = 1

    weights_path = f"{model_dir}/weights.ckpt"
    syms_path = f"{model_dir}/syms.txt"
    language_model_params = {"language_model_weight": 1.0}
    use_language_model = True
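    # Optional LM-rescored decoding: the paths below are assumed to point to a
    # KenLM-format binary language model plus its lexicon and token files.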
    if use_language_model:
        language_model_params.update({
            "language_model_path": f"{model_dir}/language_model.binary",
            "lexicon_path": f"{model_dir}/lexicon.txt",
            "tokens_path": f"{model_dir}/tokens.txt",
        })

    common_args = CommonArgs(
        checkpoint="weights.ckpt",
        train_path=f"{model_dir}",
        experiment_dirname="",
    )

    data_args = DataArgs(batch_size=batch_size, color_mode="L")
    trainer_args = TrainerArgs(progress_bar_refresh_rate=0)
    decode_args = DecodeArgs(
        include_img_ids=True,
        join_string="",
        convert_spaces=True,
        print_line_confidence_scores=True,
        print_word_confidence_scores=False,
        temperature=temperature,
        use_language_model=use_language_model,
        **language_model_params,
    )

    with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
        image_id = uuid4()
        input_img = input_img.resize((int(get_width(input_img)), DEFAULT_HEIGHT))
        input_img.save(f"{images}/{image_id}.jpg")
        Path(img_list.name).write_text("\n".join([str(image_id)]))

        # decode() writes its predictions to stdout; capture them in a temp file
        with open(pred_stdout.name, mode="w") as pred_file, redirect_stdout(pred_file):
            decode(
                syms=str(syms_path),
                img_list=img_list.name,
                img_dirs=[str(images)],
                common=common_args,
                data=data_args,
                trainer=trainer_args,
                decode=decode_args,
                num_workers=1,
            )
            sys.stdout.flush()
        predictions = Path(pred_stdout.name).read_text().strip().splitlines()

    match = LINE_PREDICTION.match(predictions[0])
    if match is None:
        raise ValueError(f"Unexpected prediction format: {predictions[0]!r}")
    _, score, text = match.groups()
    if TEXT_DIRECTION == "RTL":
        return input_img, {"text": get_display(text), "score": score}
    else:
        return input_img, {"text": text, "score": score}
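
# Example usage (hypothetical path "line.jpg"): recognize one cropped line:
#   _, pred = predict("pylaia-samaritan_v1", Image.open("line.jpg"))
#   print(pred["text"], pred["score"])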

def process_image(image):
    # Run YOLOv8 inference, keeping only text-line detections (class 0)
    results = model(image, classes=0)

    img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    boxes = results[0].boxes.xyxy.tolist()
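    # Sort boxes top-to-bottom (by y1) so lines are read in page order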
    boxes.sort(key=lambda x: x[1])

    bboxes = []
    polygons = []
    texts = []

    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = map(int, box)
        crop_img = img_cv2[y1:y2, x1:x2]
        crop_pil = Image.fromarray(cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB))

        # Recognize text using PyLaia model
        predicted = predict('pylaia-samaritan_v1', crop_pil)
        texts.append(predicted[1]["text"])

        bboxes.append((x1, y1, x2, y2))
        polygons.append(f"Line {i+1}: {[(x1, y1), (x2, y1), (x2, y2), (x1, y2)]}")

        # Draw bounding box
        cv2.rectangle(img_cv2, (x1, y1), (x2, y2), (0, 255, 0), 2)

    # Convert image back to RGB for display in Streamlit
    img_result = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)

    # Combine polygons and texts into a DataFrame for table display
    table_data = pd.DataFrame({"Polygons": polygons, "Recognized Text": texts})
    return Image.fromarray(img_result), table_data

def segment_and_recognize(image):
    segmented_image, table_data = process_image(image)
    return segmented_image, table_data

# Streamlit app layout
st.title("YOLOv8 Text Line Segmentation & PyLaia Text Recognition")

# File uploader
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Process the image if uploaded
if uploaded_image is not None:
    image = Image.open(uploaded_image).convert("RGB")  # ensure 3-channel RGB for OpenCV

    if st.button("Segment and Recognize"):
        # Perform segmentation and recognition
        segmented_image, table_data = segment_and_recognize(image)

        # Display the segmented image
        st.image(segmented_image, caption="Segmented Image with Bounding Boxes", use_column_width=True)

        # Display the table with polygons and recognized text
        st.table(table_data)
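
# To launch locally (assuming this script is saved as app.py):
#   streamlit run app.py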