|
import streamlit as st |
|
import warnings |
|
warnings.simplefilter("ignore", UserWarning) |
|
|
|
from uuid import uuid4 |
|
from laia.scripts.htr.decode_ctc import run as decode |
|
from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs |
|
import sys |
|
from tempfile import NamedTemporaryFile, mkdtemp |
|
from pathlib import Path |
|
from contextlib import redirect_stdout |
|
import re |
|
from PIL import Image |
|
from bidi.algorithm import get_display |
|
import multiprocessing |
|
from ultralytics import YOLO |
|
import cv2 |
|
import numpy as np |
|
import pandas as pd |
|
import logging |
|
from typing import List, Optional |
|
|
|
|
|
# Keep PyTorch Lightning quiet while the decoder runs: only errors are logged.
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)

# YOLO segmentation weights used by process_image().
# NOTE(review): this line runs every time the module is executed (presumably on
# every Streamlit rerun), so the weights are reloaded each time. If caching is
# added (e.g. st.cache_resource), check Ultralytics' guidance on sharing one
# model instance across threads first.
model = YOLO('model.pt')

# Scratch directory where predict() writes the line images handed to the decoder.
images = Path(mkdtemp())

# Height in pixels that line images are rescaled to before recognition.
DEFAULT_HEIGHT = 128
# When set to "RTL", predict() reorders recognized text with bidi's get_display.
TEXT_DIRECTION = "LTR"
# NOTE(review): unused in the visible code (predict() passes num_workers=1).
NUM_WORKERS = multiprocessing.cpu_count()

# A decoder output line has the shape "<image_id> <confidence> <text>".
IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"  # matches str(uuid4())
CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"
TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
LINE_PREDICTION = re.compile(" ".join((IMAGE_ID_PATTERN, CONFIDENCE_PATTERN, TEXT_PATTERN)))
|
|
|
def get_width(image, height=DEFAULT_HEIGHT):
    """Return the width that keeps *image*'s aspect ratio when scaled to *height*.

    The value is a float; callers are responsible for rounding or truncating it.
    """
    ratio = image.width / image.height
    return ratio * height
|
|
|
def simplify_polygons(
    polygons: List[np.ndarray],
    approx_level: float = 0.01,
    min_points: int = 4,
) -> List[Optional[np.ndarray]]:
    """Simplify closed polygon contours with the Douglas-Peucker algorithm.

    Each contour is approximated with ``cv2.approxPolyDP`` using a tolerance of
    ``approx_level`` times the contour's (closed) perimeter.

    Args:
        polygons: Contours, e.g. as returned by ``cv2.findContours``.
        approx_level: Tolerance as a fraction of the perimeter. Higher values
            allow larger deviations from the original outline and therefore
            remove more vertices (more simplification). Lower values keep the
            polygon closer to the original.
        min_points: Minimum vertex count a polygon must have, both before and
            after simplification, to be kept. Defaults to 4.

    Returns:
        One entry per input contour, in input order. Each entry is either the
        simplified vertices with singleton axes squeezed out, or ``None`` when
        the contour has fewer than ``min_points`` vertices before or after
        simplification.
    """

    def _simplify(polygon):
        # Simplify one contour, or return None if it is too small to keep.
        if len(polygon) < min_points:
            return None
        # Scaling epsilon by the perimeter makes the tolerance size-independent.
        perimeter = cv2.arcLength(polygon, True)
        approx = cv2.approxPolyDP(polygon, approx_level * perimeter, True)
        if len(approx) < min_points:
            return None
        return approx.squeeze()

    return [_simplify(polygon) for polygon in polygons]
|
|
|
def predict(model_name, input_img):
    """Transcribe one text-line image with the PyLaia CTC decoder.

    The line is rescaled to ``DEFAULT_HEIGHT`` pixels tall (aspect ratio kept)
    and written as a JPEG into the ``images`` scratch directory. It is then
    decoded with ``laia``'s ``decode_ctc.run``, and the prediction is parsed
    from the decoder's captured stdout. The JPEG is deleted afterwards.

    Args:
        model_name: Ignored. The checkpoint, symbols and language model are
            always read from ``catmus-medieval``. NOTE(review): the visible
            caller passes ``'pylaia-samaritan_v1'``; confirm which model is meant.
        input_img: PIL image of a single text line.

    Returns:
        A tuple ``(resized_img, {"text": text, "score": score})``:
        ``resized_img`` is the rescaled PIL image; ``text`` is the decoded
        transcription, reordered with ``get_display`` when ``TEXT_DIRECTION``
        is ``"RTL"``; ``score`` is the line confidence string as printed by
        the decoder.

    Raises:
        RuntimeError: If no line of the decoder output matches ``LINE_PREDICTION``.
    """
    model_dir = 'catmus-medieval'
    temperature = 2.0
    batch_size = 1
    syms_path = f"{model_dir}/syms.txt"

    use_language_model = True
    language_model_params = {"language_model_weight": 1.0}
    if use_language_model:
        language_model_params.update({
            "language_model_path": f"{model_dir}/language_model.binary",
            "lexicon_path": f"{model_dir}/lexicon.txt",
            "tokens_path": f"{model_dir}/tokens.txt",
        })

    common_args = CommonArgs(
        checkpoint="weights.ckpt",
        train_path=f"{model_dir}",
        experiment_dirname="",
    )
    data_args = DataArgs(batch_size=batch_size, color_mode="L")
    trainer_args = TrainerArgs(progress_bar_refresh_rate=0)
    decode_args = DecodeArgs(
        include_img_ids=True,
        join_string="",
        convert_spaces=True,
        print_line_confidence_scores=True,
        print_word_confidence_scores=False,
        temperature=temperature,
        use_language_model=use_language_model,
        **language_model_params,
    )

    # Very narrow crops can truncate to width 0, which PIL's resize rejects.
    width = max(1, int(get_width(input_img)))
    input_img = input_img.resize((width, DEFAULT_HEIGHT))

    image_id = uuid4()
    image_path = images / f"{image_id}.jpg"
    try:
        # JPEG cannot store alpha or palette images; convert only those modes.
        to_save = input_img if input_img.mode in ("RGB", "L") else input_img.convert("RGB")
        to_save.save(image_path)

        with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
            # The list names the saved image by its id (file name without extension).
            Path(img_list.name).write_text(str(image_id))
            # redirect_stdout does not close its target, so the capture file gets
            # its own context manager. Closing it flushes the output before it is
            # read back. UTF-8 keeps non-ASCII transcriptions intact.
            with open(pred_stdout.name, mode="w", encoding="utf-8") as out, redirect_stdout(out):
                decode(
                    syms=str(syms_path),
                    img_list=img_list.name,
                    img_dirs=[str(images)],
                    common=common_args,
                    data=data_args,
                    trainer=trainer_args,
                    decode=decode_args,
                    num_workers=1,
                )
            output = Path(pred_stdout.name).read_text(encoding="utf-8")
    finally:
        # The scratch directory is shared by all calls; do not let images pile up.
        image_path.unlink(missing_ok=True)

    # Use the first line shaped like "<image_id> <confidence> <text>".
    for line in output.strip().splitlines():
        match = LINE_PREDICTION.match(line)
        if match is not None:
            break
    else:
        raise RuntimeError(f"PyLaia produced no prediction for image {image_id}: {output!r}")

    score, text = match.group("confidence", "text")
    if TEXT_DIRECTION == "RTL":
        text = get_display(text)
    return input_img, {"text": text, "score": score}
|
|
|
def process_image(image):
    """Segment text lines with YOLO, recognize each line, and draw the outlines.

    Args:
        image: PIL image of a page. It is converted to RGB first, so grayscale,
            palette and RGBA uploads are accepted.

    Returns:
        A tuple ``(annotated_image, table_data)``:

        - ``annotated_image``: a PIL RGB image with each kept line's simplified
          polygon drawn in green.
        - ``table_data``: a DataFrame with columns ``"Polygons"``
          (``"Line k: [[x, y], ...]"``) and ``"Recognized Text"``. It has one
          row per kept line, ordered by the top edge of the line's box.
    """
    # Non-RGB modes (L, P, RGBA) would break the RGB->BGR conversion below.
    image = image.convert("RGB")

    # retina_masks=True makes Ultralytics return masks at the original image
    # size. The default masks are at the letterboxed inference size (including
    # padding), and stretching them onto the page misaligns them.
    results = model(image, classes=0, retina_masks=True)

    img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    # Recognition crops come from this untouched copy. img_cv2 is drawn on inside
    # the loop, so later lines would otherwise contain earlier lines' outlines.
    page_cv2 = img_cv2.copy()
    img_height, img_width = img_cv2.shape[:2]

    polygons = []
    texts = []

    masks = results[0].masks
    if masks is not None:
        mask_data = masks.data.cpu().numpy()
        boxes = results[0].boxes.xyxy.cpu().numpy()

        # Visit detections top to bottom by the upper edge (y1) of their box.
        order = np.argsort(boxes[:, 1])

        for i, (mask, box) in enumerate(zip(mask_data[order], boxes[order])):
            # Should be a no-op with retina masks; kept as a size safeguard.
            mask = cv2.resize(mask.squeeze(), (img_width, img_height), interpolation=cv2.INTER_LINEAR)
            binary = (mask > 0.5).astype(np.uint8) * 255

            contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if not contours:
                continue

            polygon = simplify_polygons([max(contours, key=cv2.contourArea)])[0]
            if polygon is None:
                continue

            x1, y1, x2, y2 = map(int, box)
            # Clamp at 0: negative indices would wrap around in numpy slicing.
            crop = page_cv2[max(y1, 0):y2, max(x1, 0):x2]
            if crop.size:
                crop_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
                # predict() currently ignores the model name argument.
                texts.append(predict('pylaia-samaritan_v1', crop_pil)[1]["text"])
            else:
                # Degenerate box: nothing to recognize, but still report the polygon.
                texts.append("")

            poly_points = polygon.reshape(-1, 2).astype(int).tolist()
            polygons.append(f"Line {i+1}: {poly_points}")

            # cv2.polylines requires int32 point arrays.
            cv2.polylines(img_cv2, [polygon.reshape(-1, 1, 2).astype(np.int32)],
                          True, (0, 255, 0), 2)

    img_result = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
    table_data = pd.DataFrame({"Polygons": polygons, "Recognized Text": texts})
    return Image.fromarray(img_result), table_data
|
|
|
def segment_and_recognize(image):
    """Run segmentation plus recognition on *image*.

    Returns the ``(annotated_image, table_data)`` pair produced by
    ``process_image``.
    """
    annotated, table = process_image(image)
    return (annotated, table)
|
|
|
|
|
# Page layout and header.
st.set_page_config(layout="wide")
st.title("YOLOv11 Text Line Segmentation & PyLaia Text Recognition on CATMuS/medieval")

uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

if uploaded_image is not None:
    image = Image.open(uploaded_image)

    # Results are computed only when the button is clicked.
    if st.button("Segment and Recognize"):
        segmented_image, table_data = segment_and_recognize(image)

        # Narrower left column for the annotated page, wider right one for the table.
        image_col, table_col = st.columns([2, 3])
        image_col.image(segmented_image, caption="Segmented Image with Polygon Masks", use_container_width=True)
        table_col.table(table_data)
|
|
|
|