Commit · 6c8fcdd
Parent(s): 8cb1f4a
Add application file
app.py CHANGED
@@ -1,3 +1,27 @@
+"""
+Gradio application for performing OCR on scanned Old Nepali documents.
+
+This script is a Gradio port of a Streamlit application originally built
+to visualize and edit OCR output. It loads a pre‑trained model for
+sequence decoding, accepts an input image (and optional segmentation
+XML in ALTO format), performs OCR on segmented lines, highlights tokens
+with low confidence and offers downloads of both the raw text and per
+token scores.
+
+The heavy lifting functions (model loading, pre‑processing, inference
+and highlighting) are adapted directly from the Streamlit version. The
+UI has been simplified for Gradio: users upload an image and optional
+XML file, choose preprocessing steps and a highlight metric, then run
+OCR. The results are displayed alongside the overlaid segmentation
+boxes and a table of token scores. An editable textbox lets users
+modify the predicted text before downloading it.
+
+To run this app locally, install gradio (`pip install gradio`) and
+execute this script with Python:
+
+    python gradio_app.py
+
+"""
 
 import io
 import os
@@ -22,18 +46,68 @@ from transformers import (
 from matplotlib import cm
 import gradio as gr
 
+# ----------------------------------------------------------------------
+# Configuration
+#
+# These constants control various aspects of the OCR pipeline. You can
+# adjust them to trade off accuracy, performance or output volume.
+
+# The maximum number of tokens to decode for a single line. If your
+# documents typically have longer lines you can increase this value, but
+# beware that very long sequences may cause more memory usage.
 MAX_LEN: int = 128
+
+# How many alternative tokens to keep when computing per–token statistics.
 TOPK: int = 3
+
+# If an XML segmentation file is provided, only process the first
+# MAX_LINES lines. This prevents huge documents from consuming
+# excessive resources.
 MAX_LINES: int = 120
+
+# Images are resized such that the longest side does not exceed this
+# number of pixels before passing them to the OCR model. Increasing
+# this value may improve accuracy at the cost of speed and memory.
 RESIZE_MAX_SIDE: int = 800
+
+# Threshold used when highlighting tokens by relative probability. A
+# ratio of Top2/Top1 greater than this value will cause the token to
+# be highlighted in red.
 REL_PROB_TH: float = 0.70
+
+# A regex used to clean up Unicode control characters before text
+# normalization. Soft hyphens, zero width spaces and similar marks
+# interfere with accurate token matching.
 CLEANUP: re.Pattern = re.compile(r"[\u00AD\u200B\u200C\u200D]")
+
+# Default font path for rendering predictions directly on the image.
 FONT_PATH: str = os.path.join("NotoSansDevanagari-Regular.ttf")
 
 
+# ----------------------------------------------------------------------
+# Model loading
+#
+# Loading the model and associated tokenizer/processor is slow. Use
+# functools.lru_cache to ensure this only happens once per process.
+
 @lru_cache(maxsize=1)
 def load_model():
+    """Load the OCR model, tokenizer and feature extractor.
+
+    Returns
+    -------
+    model : VisionEncoderDecoderModel
+        The loaded model in evaluation mode.
+    tokenizer : PreTrainedTokenizerFast
+        Tokenizer corresponding to the decoder part of the model.
+    feature_extractor : callable
+        Feature extractor converting PIL images into model inputs.
+    device : torch.device
+        The device (CPU or CUDA) used for inference.
+    """
     model_path = "AnjaliSarawgi/model-oct"
+    # In an offline environment the HF token is None; if you wish
+    # to use a private model you can set HF_TOKEN in your environment.
     hf_token = os.environ.get("HF_TOKEN")
     model = VisionEncoderDecoderModel.from_pretrained(model_path, token=hf_token)
     tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path, token=hf_token)
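Note: the comment above relies on functools.lru_cache to make sure the checkpoint is loaded only once per process. A minimal sketch of that caching behaviour, using a stand-in loader instead of the real model-oct download and HF_TOKEN handling:

    from functools import lru_cache

    @lru_cache(maxsize=1)
    def load_model_stub():
        # Stand-in for the expensive VisionEncoderDecoderModel/tokenizer load.
        print("loading (runs only once)")
        return {"model": "weights"}, {"tokenizer": "vocab"}

    a = load_model_stub()   # prints "loading (runs only once)"
    b = load_model_stub()   # served from the cache, no second load
    assert a is b           # the exact same cached tuple is returned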
@@ -43,13 +117,45 @@ def load_model():
     return model, tokenizer, processor.feature_extractor, device
 
 
+# ----------------------------------------------------------------------
+# Utility functions
+#
+
 def clean_text(text: str) -> str:
+    """Normalize and collapse whitespace from a decoded string.
+
+    Parameters
+    ----------
+    text : str
+        The raw decoded string from the model.
+
+    Returns
+    -------
+    str
+        The cleaned string with Unicode normalization and whitespace
+        removed. All whitespace characters are stripped since the
+        predictions are later tokenized at the akshara (syllable) level.
+    """
     text = unicodedata.normalize("NFC", text)
     text = CLEANUP.sub("", text)
     return re.sub(r"\s+", "", text)
 
 
 def prepare_image(image: Image.Image, max_side: int = RESIZE_MAX_SIDE) -> Image.Image:
+    """Resize the image so that its longest side equals max_side.
+
+    Parameters
+    ----------
+    image : PIL.Image
+        Input image.
+    max_side : int, optional
+        Maximum allowed size for the longest side of the image.
+
+    Returns
+    -------
+    PIL.Image
+        The resized image.
+    """
     img = image.convert("RGB")
     w, h = img.size
     if max(w, h) > max_side:
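As the new docstring describes, clean_text NFC-normalizes the prediction, strips soft hyphens and zero-width marks, and then removes all whitespace. A self-contained sketch of the same steps (the regex mirrors the CLEANUP constant above; the sample string is made up):

    import re
    import unicodedata

    CLEANUP = re.compile(r"[\u00AD\u200B\u200C\u200D]")  # soft hyphen, zero-width marks

    def clean_text(text: str) -> str:
        text = unicodedata.normalize("NFC", text)   # canonical composition
        text = CLEANUP.sub("", text)                # drop invisible control marks
        return re.sub(r"\s+", "", text)             # strip all whitespace

    print(clean_text("ने\u200dपाली  राज्य"))   # -> "नेपालीराज्य"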
@@ -58,10 +164,36 @@ def prepare_image(image: Image.Image, max_side: int = RESIZE_MAX_SIDE) -> Image.
 
 
 def get_amp_ctx():
+    """Return the appropriate context manager for automatic mixed precision."""
     return torch.cuda.amp.autocast if torch.cuda.is_available() else contextlib.nullcontext
 
+
+# ----------------------------------------------------------------------
+# XML parsing and segmentation
 #
 def parse_boxes_from_xml(xml_bytes: bytes, level: str = "line", image_size: tuple | None = None):
+    """Parse ALTO or PAGE XML to extract bounding boxes.
+
+    Parameters
+    ----------
+    xml_bytes : bytes
+        Raw XML bytes.
+    level : {"block", "line", "word"}, optional
+        The segmentation level to extract. For OCR we use "line".
+    image_size : tuple or None
+        If provided, image_size=(width, height) allows rescaling
+        coordinates to match the actual image. ALTO files often store
+        absolute page sizes that differ from the image dimensions.
+
+    Returns
+    -------
+    list of dict
+        Each dict represents a bounding box with keys:
+        - "bbox": [x1, y1, x2, y2]
+        - "points": list of (x, y) if polygonal coordinates exist
+        - "id": line identifier (string)
+        - "label": the type of element (e.g. TextLine)
+    """
     def _strip_ns(elem):
         for e in elem.iter():
             if isinstance(e.tag, str) and e.tag.startswith("{"):
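The docstring above describes the dictionaries parse_boxes_from_xml returns. A condensed sketch of the idea for ALTO input only (namespace stripping plus TextLine HPOS/VPOS/WIDTH/HEIGHT attributes; the full function in app.py also handles PAGE XML, polygon points and coordinate rescaling, none of which is shown here):

    import xml.etree.ElementTree as ET

    def parse_alto_lines(xml_bytes: bytes):
        root = ET.fromstring(xml_bytes)
        # Strip namespaces so tag names can be matched directly.
        for e in root.iter():
            if isinstance(e.tag, str) and e.tag.startswith("{"):
                e.tag = e.tag.split("}", 1)[1]
        boxes = []
        for line in root.iter("TextLine"):
            x = int(float(line.get("HPOS", 0)))
            y = int(float(line.get("VPOS", 0)))
            w = int(float(line.get("WIDTH", 0)))
            h = int(float(line.get("HEIGHT", 0)))
            boxes.append({"bbox": [x, y, x + w, y + h],
                          "id": line.get("ID", ""),
                          "label": "TextLine"})
        return boxes

    sample = b'<alto><Layout><TextLine ID="l1" HPOS="10" VPOS="20" WIDTH="300" HEIGHT="40"/></Layout></alto>'
    print(parse_alto_lines(sample))   # [{'bbox': [10, 20, 310, 60], 'id': 'l1', 'label': 'TextLine'}]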
@@ -191,6 +323,7 @@ def parse_boxes_from_xml(xml_bytes: bytes, level: str = "line", image_size: tupl
 
 
 def sort_boxes_reading_order(boxes, y_tol: int = 10):
+    """Sort bounding boxes top‑to‑bottom then left‑to‑right."""
     def key(b):
         x1, y1, x2, y2 = b["bbox"]
         return (round(y1 / max(1, y_tol)), y1, x1)
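The sort key shown above buckets boxes by y (with a tolerance) before ordering by y and x, so boxes on roughly the same baseline are read in order. A small runnable illustration with made-up boxes, assuming the function body ends with return sorted(boxes, key=key), which this hunk does not show:

    def sort_boxes_reading_order(boxes, y_tol: int = 10):
        # Sort by row bucket, then y, then x.
        def key(b):
            x1, y1, x2, y2 = b["bbox"]
            return (round(y1 / max(1, y_tol)), y1, x1)
        return sorted(boxes, key=key)

    boxes = [
        {"id": "b", "bbox": [300, 52, 500, 90]},   # same visual row as "a"
        {"id": "a", "bbox": [10, 48, 250, 88]},
        {"id": "c", "bbox": [10, 140, 500, 180]},  # next row
    ]
    print([b["id"] for b in sort_boxes_reading_order(boxes)])   # ['a', 'b', 'c']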
@@ -198,6 +331,21 @@ def sort_boxes_reading_order(boxes, y_tol: int = 10):
 
 
 def draw_boxes(img: Image.Image, boxes):
+    """Overlay semi‑transparent red polygons or rectangles on an image.
+
+    Parameters
+    ----------
+    img : PIL.Image
+        The base image.
+    boxes : list of dict
+        Segmentation boxes with either 'points' or 'bbox' keys.
+
+    Returns
+    -------
+    PIL.Image
+        An image with red overlays marking each box. Boxes are numbered
+        starting from 1.
+    """
     base = img.convert("RGBA")
     overlay = Image.new("RGBA", base.size, (0, 0, 0, 0))
     draw = ImageDraw.Draw(overlay)
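The draw_boxes docstring promises semi-transparent red overlays. A minimal standalone sketch of the RGBA-overlay-plus-alpha_composite pattern used here (toy image and a single box; the real function additionally numbers each box and supports polygon points):

    from PIL import Image, ImageDraw

    def draw_one_box(img, bbox):
        base = img.convert("RGBA")
        overlay = Image.new("RGBA", base.size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(overlay)
        # Translucent fill plus an opaque outline.
        draw.rectangle(bbox, fill=(255, 0, 0, 60), outline=(255, 0, 0, 255), width=2)
        return Image.alpha_composite(base, overlay).convert("RGB")

    page = Image.new("RGB", (200, 100), "white")
    out = draw_one_box(page, [10, 20, 150, 60])
    out.save("overlay_demo.png")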
@@ -218,7 +366,33 @@ def draw_boxes(img: Image.Image, boxes):
     return Image.alpha_composite(base, overlay).convert("RGB")
 
 
+# ----------------------------------------------------------------------
+# OCR inference per line
+#
 def predict_and_score_once(image: Image.Image, line_id: int = 1, topk: int = TOPK):
+    """Run the model on a single cropped line and return predictions and scores.
+
+    This helper wraps the model.generate call to obtain per‑token
+    probabilities and derives a DataFrame summarizing each decoding step.
+
+    Parameters
+    ----------
+    image : PIL.Image
+        Cropped segment to process.
+    line_id : int, optional
+        Identifier used in the output DataFrame.
+    topk : int, optional
+        Number of alternative tokens to keep for each decoding position.
+
+    Returns
+    -------
+    decoded_text : str
+        Cleaned predicted string for the line.
+    df : pandas.DataFrame
+        Table with one row per generated token containing the following
+        columns: line_id, seq_pos, token_id, token, confidence,
+        rel_prob, entropy, gap12, alt_tokens, alt_probs.
+    """
     model, tokenizer, feature_extractor, device = load_model()
     img = prepare_image(image)
     pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values.to(device)
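The docstring above lists per-token columns such as confidence, rel_prob and entropy. A hedged sketch of how such scores can be derived from transformers' generate output (greedy decoding with output_scores=True; the real predict_and_score_once may differ in details such as beam settings and the exact gap12 definition):

    import torch

    def score_tokens(model, tokenizer, pixel_values, max_len=128, topk=3):
        out = model.generate(
            pixel_values,
            max_length=max_len,
            num_beams=1,
            output_scores=True,
            return_dict_in_generate=True,
        )
        rows = []
        for pos, step_logits in enumerate(out.scores):    # one score tensor per generated token
            probs = torch.softmax(step_logits[0], dim=-1)
            top_p, top_i = probs.topk(topk)
            entropy = float(-(probs * probs.clamp_min(1e-12).log()).sum())
            rows.append({
                "seq_pos": pos,
                "token": tokenizer.decode([int(top_i[0])]),
                "confidence": float(top_p[0]),            # probability of the chosen token
                "rel_prob": float(top_p[1] / top_p[0]),   # Top2/Top1, compared against REL_PROB_TH
                "gap12": float(top_p[0] - top_p[1]),
                "entropy": entropy,
                "alt_tokens": "|".join(tokenizer.decode([int(i)]) for i in top_i),
                "alt_probs": "|".join(f"{float(p):.4f}" for p in top_p),
            })
        return rows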
@@ -305,6 +479,13 @@ def predict_and_score_once(image: Image.Image, line_id: int = 1, topk: int = TOP
     return decoded_text, df
 
 
+# ----------------------------------------------------------------------
+# Text splitting into aksharas (syllable units) for highlighting
+#
+# The following regex and helper functions split a Devanagari string into
+# aksharas. This is necessary to map model tokens back to spans of
+# characters when highlighting uncertain predictions.
+
 DEV_CONS = "\u0915-\u0939\u0958-\u095F\u0978-\u097F"  # consonants incl. nukta variants range
 INDEP_VOW = "\u0904-\u0914"  # independent vowels
 NUKTA = "\u093C"  # nukta
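The comment block above explains why predictions are split into aksharas before highlighting. A simplified, self-contained splitter in the same spirit (the character ranges are a reduced version of DEV_CONS/INDEP_VOW above and ignore some signs the app handles):

    import re

    # One consonant cluster (optional nukta, optional virama-joined consonants),
    # an optional dependent vowel sign and optional candrabindu/anusvara/visarga;
    # otherwise an independent vowel or any single character.
    AKSHARA = re.compile(
        r"[\u0915-\u0939\u0958-\u095F]\u093C?"
        r"(?:\u094D[\u0915-\u0939\u0958-\u095F]\u093C?)*"
        r"[\u093E-\u094C]?[\u0901-\u0903]?"
        r"|[\u0904-\u0914][\u0901-\u0903]?"
        r"|."
    )

    def split_aksharas(text):
        spans = [(m.start(), m.end()) for m in AKSHARA.finditer(text)]
        return [text[s:e] for s, e in spans], spans

    aks, spans = split_aksharas("नेपाली")
    print(aks)     # ['ने', 'पा', 'ली']
    print(spans)   # [(0, 2), (2, 4), (4, 6)]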
@@ -352,23 +533,65 @@ def parse_alt_tokens(s: str):
 def highlight_tokens_with_tooltips(
     line_text: str, df_tok: pd.DataFrame, red_threshold: float, metric_column: str
 ) -> str:
+    """Insert HTML spans around tokens whose chosen metric exceeds threshold.
+
+    The metric column can be "rel_prob" (relative probability) or
+    "entropy". Tokens with a value strictly greater than red_threshold
+    will be wrapped in a span with a tooltip listing alternative
+    predictions and their probabilities.
+
+    Parameters
+    ----------
+    line_text : str
+        The cleaned line prediction.
+    df_tok : pandas.DataFrame
+        DataFrame of token statistics for the corresponding line.
+    red_threshold : float
+        Values above this threshold will be highlighted.
+    metric_column : str
+        Column name in df_tok used for thresholding.
+
+    Returns
+    -------
+    str
+        An HTML string with <span> elements inserted.
+    """
     aks, spans = split_aksharas(line_text)
     joined = "".join(aks)
-    used_ranges = []
-    insertions = []
+    used_ranges: list = []
+    insertions: list = []
+    # Define colour classification depending on the metric
+    def color_class(val: float) -> str:
+        if metric_column == "rel_prob":
+            # Use the same thresholds as the original app: >0.7 red, >=0.05 yellow, otherwise green
+            if val >= 0.70:
+                return "token-red"
+            elif val >= 0.05:
+                return "token-yellow"
+            else:
+                return "token-green"
+        else:
+            # For entropy, high values indicate uncertainty. Thresholds here are heuristics.
+            if val >= 2.0:
+                return "token-red"
+            elif val >= 1.0:
+                return "token-yellow"
+            else:
+                return "token-green"
     for _, row in df_tok.iterrows():
-        token = row.get("token", "").strip()
+        token = str(row.get("token", "")).strip()
+        if not token:
+            continue
+        # Extract metric value for classification
         try:
            val = float(row.get(metric_column, 0))
        except Exception:
-
-
-            continue
-        # Try finding the token in the joined akshara sequence
+            val = 0.0
+        # Find the first occurrence of the token in the joined akshara sequence
         start_char_idx = joined.find(token)
         if start_char_idx == -1:
             continue
-        # Locate
+        # Locate corresponding akshara boundaries
         ak_start = ak_end = None
         cum_len = 0
         for i, ak in enumerate(aks):
@@ -381,17 +604,16 @@ def highlight_tokens_with_tooltips(
             cum_len = next_len
         if ak_start is None or ak_end is None:
             continue
-        #
+        # Prevent overlapping spans
         if any(r[0] < ak_end and ak_start < r[1] for r in used_ranges):
             continue
         used_ranges.append((ak_start, ak_end))
-        # Character positions
         char_start = spans[ak_start][0]
         char_end = spans[ak_end - 1][1]
-        #
+        # Prepare tooltip content
         alt_toks = row.get("alt_tokens", "").split("|")
         alt_probs = row.get("alt_probs", "").split("|")
-        tooltip_lines = []
+        tooltip_lines: list = []
         for t, p in zip(alt_toks, alt_probs):
             try:
                 prob = float(p)
@@ -400,12 +622,16 @@
             tooltip_lines.append(f"{_html_escape(t)}: {prob:.3f}")
         tooltip = "\n".join(tooltip_lines)
         token_str = _html_escape(line_text[char_start:char_end])
-
+        cls = color_class(val)
+        html_token = (
+            f"<span class='ocr-token {cls}' data-tooltip='{_html_escape(tooltip)}'>{token_str}</span>"
+        )
         insertions.append((char_start, char_end, html_token))
+    # If nothing was highlighted, return escaped original text
     if not insertions:
         return _html_escape(line_text)
     insertions.sort()
-    out_parts = []
+    out_parts: list = []
     last_idx = 0
     for s, e, html_tok in insertions:
         out_parts.append(_html_escape(line_text[last_idx:s]))
@@ -435,7 +661,40 @@ def run_ocr(
     apply_bin: bool,
     highlight_metric: str,
 ):
-
+    """Run the OCR pipeline on user inputs and return results for Gradio.
+
+    Parameters
+    ----------
+    image : numpy.ndarray or None
+        The uploaded image converted to a NumPy array by Gradio. If
+        None, the function returns empty results.
+    xml_file : tuple or None
+        A tuple representing the uploaded XML file as provided by
+        gr.File. The first element is the file name and the second is
+        bytes. If None, no segmentation is applied and the entire
+        image is processed as a single line.
+    apply_gray : bool
+        Whether to convert the image to grayscale before OCR.
+    apply_bin : bool
+        Whether to apply binarization (Otsu threshold) before OCR. If
+        selected, grayscale conversion is applied first automatically.
+    highlight_metric : str
+        Which metric to use for highlighting ("Relative Probability" or
+        "Entropy").
+
+    Returns
+    -------
+    overlay_img : PIL.Image or None
+        Image with segmentation boxes drawn. None if no input image.
+    predictions_html : str
+        HTML formatted predicted text with highlighted tokens.
+    df_scores : pandas.DataFrame or None
+        DataFrame of per‑token statistics. None if no input image.
+    txt_file_path : str or None
+        Path to a temporary .txt file containing the plain predicted text.
+    csv_file_path : str or None
+        Path to a temporary CSV file containing the extended token scores.
+    """
     if image is None:
         return None, "", None, None, None
     # Convert the numpy array to a PIL image
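The updated return values are file paths rather than raw bytes (compare the removed txt_bytes/csv_bytes lines later in this diff), because gr.File outputs serve files from disk. A hedged sketch of writing those temporary files (the helper name and sample DataFrame are illustrative):

    import tempfile
    import pandas as pd

    def write_downloads(plain_text: str, df_all: pd.DataFrame):
        # Plain-text prediction for the "Download OCR Prediction (.txt)" output.
        txt_fd = tempfile.NamedTemporaryFile(mode="w", suffix=".txt",
                                             delete=False, encoding="utf-8")
        txt_fd.write(plain_text)
        txt_fd.close()
        # Full per-token table for the "Download Token Scores (.csv)" output.
        csv_fd = tempfile.NamedTemporaryFile(mode="w", suffix=".csv",
                                             delete=False, encoding="utf-8")
        df_all.to_csv(csv_fd, index=False)
        csv_fd.close()
        return txt_fd.name, csv_fd.name

    txt_path, csv_path = write_downloads("नमस्ते", pd.DataFrame({"token": ["नम", "स्ते"]}))
    print(txt_path, csv_path)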
@@ -548,6 +807,32 @@ def run_ocr(
         if group.iloc[-1]["token"].strip() == "":
             to_drop.append(group.index[-1])
     df_all = df_all.drop(index=to_drop)
+    # Inject style definitions for token colouring and tooltips
+    style_block = """
+    <style>
+    .ocr-token { position: relative; cursor: pointer; padding: 0 2px; border-radius: 2px; }
+    .ocr-token.token-red { background-color: rgba(255, 107, 107, 0.7); }
+    .ocr-token.token-yellow { background-color: rgba(255, 217, 59, 0.7); }
+    .ocr-token.token-green { background-color: rgba(107, 207, 99, 0.7); }
+    .ocr-token:hover::after {
+        content: attr(data-tooltip);
+        position: absolute;
+        bottom: 120%;
+        left: 0;
+        white-space: pre-line;
+        background: #333;
+        color: #fff;
+        padding: 6px 10px;
+        border-radius: 6px;
+        font-size: 12px;
+        z-index: 999;
+        max-width: 220px;
+        box-shadow: 0 4px 12px rgba(0,0,0,0.15);
+    }
+    </style>
+    """
+    if predicted_html:
+        predicted_html = style_block + predicted_html
     # Prepare plain text by stripping HTML tags and replacing <br>
     plain_text = re.sub(r"<[^>]*>", "", predicted_html.replace("<br>", "\n"))
     # Write temporary files
@@ -569,10 +854,7 @@ def run_ocr(
         csv_fd.close()
     except Exception:
         csv_path = None
-
-    txt_bytes = plain_text.encode("utf-8")
-    csv_bytes = df_all.to_csv(index=False).encode("utf-8")
-    return overlay_img, predicted_html, df_all, txt_bytes, csv_bytes
+    return overlay_img, predicted_html, df_all, txt_path, csv_path
 
 
 # ----------------------------------------------------------------------
@@ -581,7 +863,7 @@ def run_ocr(
 def create_gradio_interface():
     """Create and return the Gradio Blocks interface."""
     with gr.Blocks(title="Old Nepali HTR") as demo:
-        gr.Markdown("""# Old Nepali HTR \n\nUpload a scanned image and (optionally) a segmentation XML file. Choose preprocessing\nsteps and a highlight metric, then click **Run OCR** to extract the text.\nUncertain tokens are highlighted with tooltips showing alternative predictions.\nYou can edit the plain text below and download it or the full token scores.""")
+        gr.Markdown("""# Old Nepali HTR (Gradio)\n\nUpload a scanned image and (optionally) a segmentation XML file. Choose preprocessing\nsteps and a highlight metric, then click **Run OCR** to extract the text.\nUncertain tokens are highlighted with tooltips showing alternative predictions.\nYou can edit the plain text below and download it or the full token scores.""")
         with gr.Row():
             image_input = gr.Image(type="numpy", label="Upload Image")
             # When used as an input, gr.File returns either a file path or bytes
@@ -607,11 +889,8 @@ def create_gradio_interface():
             predictions_output = gr.HTML(label="Predictions (HTML)")
             df_output = gr.DataFrame(label="Token Scores", interactive=False)
             # Separate file outputs for the OCR prediction, token scores and edited text.
-
-
-            ocr_txt_output = gr.File(label="Download OCR Prediction (.txt)", type="binary")
-            ocr_csv_output = gr.File(label="Download Token Scores (.csv)", type="binary")
-
+            ocr_txt_output = gr.File(label="Download OCR Prediction (.txt)")
+            ocr_csv_output = gr.File(label="Download Token Scores (.csv)")
             edited_txt_output = gr.File(label="Download edited text (.txt)")
 
             # Editable text area
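Dropping type="binary" from the gr.File outputs matches run_ocr now returning file paths: a gr.File output renders whatever path the callback returns as a downloadable file. A hedged wiring sketch (the stand-in callback and component subset are illustrative; the app's actual click wiring is outside this hunk):

    import tempfile
    import gradio as gr

    def fake_run_ocr(image):
        # Stand-in for run_ocr: writes a real temp file and returns its path for gr.File.
        txt = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False, encoding="utf-8")
        txt.write("predicted text goes here")
        txt.close()
        return "<b>predicted text goes here</b>", txt.name

    with gr.Blocks(title="Old Nepali HTR") as demo:
        image_input = gr.Image(type="numpy", label="Upload Image")
        run_button = gr.Button("Run OCR")
        predictions_output = gr.HTML(label="Predictions (HTML)")
        ocr_txt_output = gr.File(label="Download OCR Prediction (.txt)")
        run_button.click(fn=fake_run_ocr, inputs=[image_input],
                         outputs=[predictions_output, ocr_txt_output])

    # demo.launch()  # uncomment to try it locally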
@@ -662,7 +941,5 @@ def create_gradio_interface():
         )
     return demo
 
-
-
 iface = create_gradio_interface()
 iface.launch()