import streamlit as st
import warnings
warnings.simplefilter("ignore", UserWarning)
from uuid import uuid4
from laia.scripts.htr.decode_ctc import run as decode
from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs
import sys
from tempfile import NamedTemporaryFile, mkdtemp
from pathlib import Path
from contextlib import redirect_stdout
import re
from PIL import Image
from bidi.algorithm import get_display
import multiprocessing
from ultralytics import YOLO
import cv2
import numpy as np
import pandas as pd
import logging
# Configure logging
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
# Load YOLOv8 model
model = YOLO('model.pt')
images = Path(mkdtemp())  # temporary directory for the line crops fed to PyLaia
DEFAULT_HEIGHT = 128
TEXT_DIRECTION = "RTL"
NUM_WORKERS = multiprocessing.cpu_count()
# Regex pattern for extracting results
IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"  # per-line confidence score
TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
LINE_PREDICTION = re.compile(rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN} {TEXT_PATTERN}")
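# A decoded prediction line is expected to look like this (hypothetical values):
#   "9f1c2a34-5b6d-4e7f-8a90-b1c2d3e4f5a6 0.9471 some decoded text"
# i.e. the 36-character image UUID, the line confidence, then the text.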
def get_width(image, height=DEFAULT_HEIGHT):
    aspect_ratio = image.width / image.height
    return height * aspect_ratio
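# For example, a 1000x250 line image has aspect ratio 4.0, so get_width(img)
# returns 128 * 4.0 = 512.0 and the crop is resized to 512x128.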
def predict(model_name, input_img):
    model_dir = model_name  # directory holding the PyLaia model files
    temperature = 2.0
    batch_size = 1
    weights_path = f"{model_dir}/weights.ckpt"
    syms_path = f"{model_dir}/syms.txt"
    language_model_params = {"language_model_weight": 1.0}
    use_language_model = True
    if use_language_model:
        language_model_params.update({
            "language_model_path": f"{model_dir}/language_model.binary",
            "lexicon_path": f"{model_dir}/lexicon.txt",
            "tokens_path": f"{model_dir}/tokens.txt",
        })
    common_args = CommonArgs(
        checkpoint="weights.ckpt",
        train_path=f"{model_dir}",
        experiment_dirname="",
    )
    data_args = DataArgs(batch_size=batch_size, color_mode="L")
    trainer_args = TrainerArgs(progress_bar_refresh_rate=0)
    decode_args = DecodeArgs(
        include_img_ids=True,
        join_string="",
        convert_spaces=True,
        print_line_confidence_scores=True,
        print_word_confidence_scores=False,
        temperature=temperature,
        use_language_model=use_language_model,
        **language_model_params,
    )
    with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
        image_id = uuid4()
        # Resize to the fixed line height the model expects, preserving aspect ratio
        input_img = input_img.resize((int(get_width(input_img)), DEFAULT_HEIGHT))
        input_img.save(f"{images}/{image_id}.jpg")
        Path(img_list.name).write_text(str(image_id))
        # Capture the decoder's stdout so the predictions can be parsed below
        with open(pred_stdout.name, mode="w") as f, redirect_stdout(f):
            decode(
                syms=str(syms_path),
                img_list=img_list.name,
                img_dirs=[str(images)],
                common=common_args,
                data=data_args,
                trainer=trainer_args,
                decode=decode_args,
                num_workers=1,
            )
            sys.stdout.flush()
        predictions = Path(pred_stdout.name).read_text().strip().splitlines()
        _, score, text = LINE_PREDICTION.match(predictions[0]).groups()
        if TEXT_DIRECTION == "RTL":
            return input_img, {"text": get_display(text), "score": score}
        else:
            return input_img, {"text": text, "score": score}
def process_image(image):
    # Run YOLOv8 inference, keeping only the text line class (class 0)
    results = model(image, classes=0)
    img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    # Sort detected boxes top-to-bottom so lines come out in reading order
    boxes = results[0].boxes.xyxy.tolist()
    boxes.sort(key=lambda x: x[1])
    bboxes = []
    polygons = []
    texts = []
    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = map(int, box)
        crop_img = img_cv2[y1:y2, x1:x2]
        crop_pil = Image.fromarray(cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB))
        # Recognize the line's text using the PyLaia model
        predicted = predict('pylaia-samaritan_v1', crop_pil)
        texts.append(predicted[1]["text"])
        bboxes.append((x1, y1, x2, y2))
        polygons.append(f"Line {i+1}: {[(x1, y1), (x2, y1), (x2, y2), (x1, y2)]}")
        # Draw the bounding box on the output image
        cv2.rectangle(img_cv2, (x1, y1), (x2, y2), (0, 255, 0), 2)
    # Convert image back to RGB for display in Streamlit
    img_result = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
    # Combine polygons and texts into a DataFrame for table display
    table_data = pd.DataFrame({"Polygons": polygons, "Recognized Text": texts})
    return Image.fromarray(img_result), table_data
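# The returned table pairs each line polygon with its transcription, e.g.
# (hypothetical values):
#   Polygons                                              Recognized Text
#   Line 1: [(34, 18), (880, 18), (880, 92), (34, 92)]    <first decoded line>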
def segment_and_recognize(image):
    segmented_image, table_data = process_image(image)
    return segmented_image, table_data
# Streamlit app layout
st.title("YOLOv8 Text Line Segmentation & PyLaia Text Recognition")
# File uploader
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
# Process the image if uploaded
if uploaded_image is not None:
    image = Image.open(uploaded_image).convert("RGB")
    if st.button("Segment and Recognize"):
        # Perform segmentation and recognition
        segmented_image, table_data = segment_and_recognize(image)
        # Display the segmented image
        st.image(segmented_image, caption="Segmented Image with Bounding Boxes", use_column_width=True)
        # Display the table with polygons and recognized text
        st.table(table_data)