AnjaliSarawgi committed
Commit 185acd0 · 1 Parent(s): 719cede

Add application file

Files changed (1): app.py (+899, -0)
app.py ADDED
@@ -0,0 +1,899 @@
+ """
+ Gradio application for performing OCR on scanned Old Nepali documents.
+
+ This script is a Gradio port of a Streamlit application originally built
+ to visualize and edit OCR output. It loads a pre-trained model for
+ sequence decoding, accepts an input image (and optional segmentation
+ XML in ALTO format), performs OCR on segmented lines, highlights tokens
+ with low confidence, and offers downloads of both the raw text and
+ per-token scores.
+
+ The heavy-lifting functions (model loading, preprocessing, inference
+ and highlighting) are adapted directly from the Streamlit version. The
+ UI has been simplified for Gradio: users upload an image and an optional
+ XML file, choose preprocessing steps and a highlight metric, then run
+ OCR. The results are displayed alongside the overlaid segmentation
+ boxes and a table of token scores. An editable textbox lets users
+ modify the predicted text before downloading it.
+
+ To run this app locally, install the dependencies (gradio, torch,
+ transformers, opencv-python, pandas, matplotlib, pillow) and execute
+ this script with Python:
+
+     python app.py
+ """
+
+ import io
+ import os
+ import re
+ import base64
+ import tempfile
+ import unicodedata
+ import contextlib
+ import xml.etree.ElementTree as ET
+ from collections import defaultdict
+ from functools import lru_cache
+
+ import numpy as np
+ import pandas as pd
+ from PIL import Image, ImageDraw, ImageFont
+ import cv2
+ import torch
+ from transformers import (
+     VisionEncoderDecoderModel,
+     PreTrainedTokenizerFast,
+     TrOCRProcessor,
+ )
+ from matplotlib import cm
+ import gradio as gr
+
+ # ----------------------------------------------------------------------
+ # Configuration
+ #
+ # These constants control various aspects of the OCR pipeline. You can
+ # adjust them to trade off accuracy, performance or output volume.
+
+ # The maximum number of tokens to decode for a single line. If your
+ # documents typically have longer lines you can increase this value, but
+ # beware that very long sequences use more memory.
+ MAX_LEN: int = 128
+
+ # How many alternative tokens to keep when computing per-token statistics.
+ TOPK: int = 3
+
+ # If an XML segmentation file is provided, only process the first
+ # MAX_LINES lines. This prevents huge documents from consuming
+ # excessive resources.
+ MAX_LINES: int = 120
+
+ # Images are resized such that the longest side does not exceed this
+ # number of pixels before passing them to the OCR model. Increasing
+ # this value may improve accuracy at the cost of speed and memory.
+ RESIZE_MAX_SIDE: int = 800
+
+ # Threshold used when highlighting tokens by relative probability. A
+ # ratio of Top2/Top1 greater than this value causes the token to be
+ # highlighted in red.
+ REL_PROB_TH: float = 0.70
+
+ # A regex used to clean up Unicode control characters before text
+ # normalization. Soft hyphens, zero-width spaces and similar marks
+ # interfere with accurate token matching.
+ CLEANUP: re.Pattern = re.compile(r"[\u00AD\u200B\u200C\u200D]")
+
+ # Default font path for rendering predictions directly on the image.
+ FONT_PATH: str = os.path.join("NotoSansDevanagari-Regular.ttf")
+
+
+ # ----------------------------------------------------------------------
+ # Model loading
+ #
+ # Loading the model and associated tokenizer/processor is slow. Use
+ # functools.lru_cache to ensure this only happens once per process.
+
+ @lru_cache(maxsize=1)
+ def load_model():
+     """Load the OCR model, tokenizer and feature extractor.
+
+     Returns
+     -------
+     model : VisionEncoderDecoderModel
+         The loaded model in evaluation mode.
+     tokenizer : PreTrainedTokenizerFast
+         Tokenizer corresponding to the decoder part of the model.
+     feature_extractor : callable
+         Feature extractor converting PIL images into model inputs.
+     device : torch.device
+         The device (CPU or CUDA) used for inference.
+     """
+     model_path = "AnjaliSarawgi/model-oct"
+     # In an offline environment the HF token is None; if you wish to use
+     # a private model, set HF_TOKEN in your environment.
+     hf_token = os.environ.get("HF_TOKEN")
+     model = VisionEncoderDecoderModel.from_pretrained(model_path, token=hf_token)
+     tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path, token=hf_token)
+     processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten", token=None)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device).eval()
+     return model, tokenizer, processor.feature_extractor, device
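+
+ # Illustrative usage: the first call downloads and caches the weights;
+ # later calls return the lru_cache'd objects.
+ #   model, tokenizer, feature_extractor, device = load_model()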
+
+
+ # ----------------------------------------------------------------------
+ # Utility functions
+ #
+
+ def clean_text(text: str) -> str:
+     """Normalize a decoded string and strip all whitespace.
+
+     Parameters
+     ----------
+     text : str
+         The raw decoded string from the model.
+
+     Returns
+     -------
+     str
+         The cleaned string: Unicode-normalized (NFC), with control
+         characters and all whitespace removed. Whitespace is stripped
+         entirely because the predictions are later tokenized at the
+         akshara (syllable) level.
+     """
+     text = unicodedata.normalize("NFC", text)
+     text = CLEANUP.sub("", text)
+     return re.sub(r"\s+", "", text)
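+
+ # Example (illustrative): clean_text("रा म\u200b") -> "राम". The zero-width
+ # space is stripped by CLEANUP and the inner space by the \s+ substitution.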
+
+
+ def prepare_image(image: Image.Image, max_side: int = RESIZE_MAX_SIDE) -> Image.Image:
+     """Resize the image so that its longest side is at most max_side.
+
+     Parameters
+     ----------
+     image : PIL.Image
+         Input image.
+     max_side : int, optional
+         Maximum allowed size for the longest side of the image.
+
+     Returns
+     -------
+     PIL.Image
+         The resized image. Images already within the limit are returned
+         unchanged apart from RGB conversion.
+     """
+     img = image.convert("RGB")
+     w, h = img.size
+     if max(w, h) > max_side:
+         img.thumbnail((max_side, max_side), Image.LANCZOS)
+     return img
+
+
+ def get_amp_ctx():
+     """Return the appropriate context manager for automatic mixed precision."""
+     return torch.cuda.amp.autocast if torch.cuda.is_available() else contextlib.nullcontext
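+
+ # Note: recent PyTorch versions prefer torch.amp.autocast("cuda") over
+ # torch.cuda.amp.autocast; the latter still works but may emit a
+ # deprecation warning.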
+
+
+ # ----------------------------------------------------------------------
+ # XML parsing and segmentation
+ #
+ def parse_boxes_from_xml(xml_bytes: bytes, level: str = "line", image_size: tuple | None = None):
+     """Parse ALTO or PAGE XML to extract bounding boxes.
+
+     Parameters
+     ----------
+     xml_bytes : bytes
+         Raw XML bytes.
+     level : {"block", "line", "word"}, optional
+         The segmentation level to extract. For OCR we use "line".
+     image_size : tuple or None
+         If provided, image_size=(width, height) allows rescaling
+         coordinates to match the actual image. ALTO files often store
+         absolute page sizes that differ from the image dimensions.
+
+     Returns
+     -------
+     list of dict
+         Each dict represents a bounding box with keys:
+         - "bbox": [x1, y1, x2, y2]
+         - "points": list of (x, y) if polygonal coordinates exist
+         - "id": line identifier (string)
+         - "label": the type of element (e.g. TextLine)
+     """
+     def _strip_ns(elem):
+         for e in elem.iter():
+             if isinstance(e.tag, str) and e.tag.startswith("{"):
+                 e.tag = e.tag.split("}", 1)[1]
+
+     root = ET.parse(io.BytesIO(xml_bytes)).getroot()
+     _strip_ns(root)
+     boxes = []
+
+     # ALTO format handling
+     if root.tag.lower() == "alto":
+         tag_map = {"block": "TextBlock", "line": "TextLine", "word": "String"}
+         tag = tag_map.get(level, "TextLine")
+         page_el = root.find(".//Page")
+         page_w = page_h = None
+         if page_el is not None:
+             try:
+                 page_w = float(page_el.get("WIDTH") or 0)
+                 page_h = float(page_el.get("HEIGHT") or 0)
+             except Exception:
+                 page_w = page_h = None
+         sx = sy = 1.0
+         if image_size and page_w and page_h:
+             img_w, img_h = image_size
+             sx = (img_w / page_w) if page_w else 1.0
+             sy = (img_h / page_h) if page_h else 1.0
+         for el in root.findall(f".//{tag}"):
+             poly = el.find(".//Shape/Polygon")
+             got_box = False
+             pts = None
+             if poly is not None and poly.get("POINTS"):
+                 raw = poly.get("POINTS").strip()
+                 tokens = re.split(r"[ ,]+", raw)
+                 nums = []
+                 for t in tokens:
+                     try:
+                         nums.append(float(t))
+                     except Exception:
+                         pass
+                 pts = []
+                 if len(nums) >= 6 and len(nums) % 2 == 0:
+                     for i in range(0, len(nums), 2):
+                         pts.append((nums[i] * sx, nums[i + 1] * sy))
+                 if pts:
+                     xs = [p[0] for p in pts]
+                     ys = [p[1] for p in pts]
+                     x1, x2 = int(min(xs)), int(max(xs))
+                     y1, y2 = int(min(ys)), int(max(ys))
+                     got_box = (x2 > x1 and y2 > y1)
+             if not got_box:
+                 try:
+                     hpos = float(el.get("HPOS", 0)) * sx
+                     vpos = float(el.get("VPOS", 0)) * sy
+                     width = float(el.get("WIDTH", 0)) * sx
+                     height = float(el.get("HEIGHT", 0)) * sy
+                     x1, y1 = int(hpos), int(vpos)
+                     x2, y2 = int(hpos + width), int(vpos + height)
+                 except Exception:
+                     continue
+             if x2 <= x1 or y2 <= y1:
+                 continue
+             label = tag if tag != "String" else (el.get("CONTENT") or "String")
+             boxes.append(
+                 {
+                     "label": label,
+                     "bbox": [x1, y1, x2, y2],
+                     "source": "alto",
+                     "id": el.get("ID", ""),
+                     **({"points": pts} if pts else {}),
+                 }
+             )
+         return boxes
+
+     # PAGE XML handling
+     for region in root.findall(".//TextRegion"):
+         coords = region.find(".//Coords")
+         pts_attr = coords.get("points") if coords is not None else None
+         if not pts_attr:
+             continue
+         pts = []
+         for token in pts_attr.strip().split():
+             if "," in token:
+                 xx, yy = token.split(",", 1)
+                 try:
+                     pts.append((float(xx), float(yy)))
+                 except Exception:
+                     pass
+         if not pts:
+             continue
+         xs = [p[0] for p in pts]
+         ys = [p[1] for p in pts]
+         x1, x2 = int(min(xs)), int(max(xs))
+         y1, y2 = int(min(ys)), int(max(ys))
+         if x2 > x1 and y2 > y1:
+             boxes.append(
+                 {
+                     "label": "TextRegion",
+                     "bbox": [x1, y1, x2, y2],
+                     "source": "page",
+                     "id": region.get("id", ""),
+                 }
+             )
+     if boxes:
+         return boxes
+     # Fallback: Pascal VOC
+     for obj in root.findall(".//object"):
+         bb = obj.find("bndbox")
+         if bb is None:
+             continue
+         try:
+             xmin = int(float(bb.findtext("xmin")))
+             ymin = int(float(bb.findtext("ymin")))
+             xmax = int(float(bb.findtext("xmax")))
+             ymax = int(float(bb.findtext("ymax")))
+             if xmax > xmin and ymax > ymin:
+                 boxes.append(
+                     {
+                         "label": (obj.findtext("name") or "region").strip(),
+                         "bbox": [xmin, ymin, xmax, ymax],
+                         "source": "voc",
+                         "id": obj.findtext("name") or "",
+                     }
+                 )
+         except Exception:
+             pass
+     return boxes
+
+
+ def sort_boxes_reading_order(boxes, y_tol: int = 10):
+     """Sort bounding boxes top-to-bottom, then left-to-right."""
+     def key(b):
+         x1, y1, x2, y2 = b["bbox"]
+         return (round(y1 / max(1, y_tol)), y1, x1)
+     return sorted(boxes, key=key)
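+
+ # Example (illustrative): with y_tol=10, boxes starting at y1=101 and y1=104
+ # fall in the same band (round(10.1) == round(10.4) == 10), so they are
+ # ordered by exact y1, with x1 breaking ties at equal heights.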
+
+
+ def draw_boxes(img: Image.Image, boxes):
+     """Overlay semi-transparent red polygons or rectangles on an image.
+
+     Parameters
+     ----------
+     img : PIL.Image
+         The base image.
+     boxes : list of dict
+         Segmentation boxes with either 'points' or 'bbox' keys.
+
+     Returns
+     -------
+     PIL.Image
+         An image with red overlays marking each box. Boxes are numbered
+         starting from 1.
+     """
+     base = img.convert("RGBA")
+     overlay = Image.new("RGBA", base.size, (0, 0, 0, 0))
+     draw = ImageDraw.Draw(overlay)
+     thickness = max(3, min(base.size) // 200)
+     for i, b in enumerate(boxes, 1):
+         if "points" in b and b["points"]:
+             pts = [(int(x), int(y)) for x, y in b["points"]]
+             draw.polygon(pts, outline=(255, 0, 0, 255), fill=(255, 0, 0, 64))
+             xs = [p[0] for p in pts]
+             ys = [p[1] for p in pts]
+             x1, y1 = min(xs), min(ys)
+         else:
+             x1, y1, x2, y2 = map(int, b["bbox"])
+             draw.rectangle([x1, y1, x2, y2], outline=(255, 0, 0, 255), width=thickness, fill=(255, 0, 0, 64))
+         tag_w, tag_h = 40, 24
+         draw.rectangle([x1, y1, x1 + tag_w, y1 + tag_h], fill=(255, 0, 0, 190))
+         draw.text((x1 + 6, y1 + 4), str(i), fill=(255, 255, 255, 255))
+     return Image.alpha_composite(base, overlay).convert("RGB")
+
+
+ # ----------------------------------------------------------------------
+ # OCR inference per line
+ #
+ def predict_and_score_once(image: Image.Image, line_id: int = 1, topk: int = TOPK):
+     """Run the model on a single cropped line and return predictions and scores.
+
+     This helper wraps the model.generate call to obtain per-token
+     probabilities and derives a DataFrame summarizing each decoding step.
+
+     Parameters
+     ----------
+     image : PIL.Image
+         Cropped segment to process.
+     line_id : int, optional
+         Identifier used in the output DataFrame.
+     topk : int, optional
+         Number of alternative tokens to keep for each decoding position.
+
+     Returns
+     -------
+     decoded_text : str
+         Cleaned predicted string for the line.
+     df : pandas.DataFrame
+         Table with one row per generated token containing the following
+         columns: line_id, seq_pos, token_id, token, confidence,
+         rel_prob, entropy, gap12, alt_tokens, alt_probs.
+     """
+     model, tokenizer, feature_extractor, device = load_model()
+     img = prepare_image(image)
+     pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values.to(device)
+     amp_ctx = get_amp_ctx()
+     with torch.inference_mode(), amp_ctx():
+         try:
+             out = model.generate(
+                 pixel_values,
+                 max_length=MAX_LEN,
+                 num_beams=5,
+                 do_sample=False,
+                 return_dict_in_generate=True,
+                 output_scores=True,
+                 use_cache=True,
+                 eos_token_id=tokenizer.eos_token_id,
+             )
+         except RuntimeError as e:
+             # In case of GPU OOM, fall back to greedy decoding without scores.
+             if "out of memory" in str(e).lower():
+                 out = model.generate(
+                     pixel_values,
+                     max_length=MAX_LEN,
+                     num_beams=1,
+                     do_sample=False,
+                     return_dict_in_generate=True,
+                     output_scores=False,
+                     use_cache=True,
+                     eos_token_id=tokenizer.eos_token_id,
+                 )
+             else:
+                 raise
+
+     seq = out.sequences[0]
+     decoded_text = clean_text(tokenizer.decode(seq, skip_special_tokens=True))
+     tokens_rows = []
+     # out.scores[i] holds the logits for token i+1 of seq. With beam search,
+     # scores[i][0] is the first beam's distribution at that step, used here as
+     # an approximation of the chosen sequence's distribution. On the OOM
+     # fallback path out.scores is None, so guard against it.
+     scores = out.scores if out.scores is not None else []
+     for step, (logits, tgt) in enumerate(zip(scores, seq[1:]), start=1):
+         probs = torch.softmax(logits[0].float().cpu(), dim=-1)
+         tgt_id = int(tgt.item())
+         conf = float(probs[tgt_id].item())
+         tk_vals, tk_idx = torch.topk(probs, k=min(topk, probs.shape[0]))
+         tk_idx = tk_idx.tolist()
+         tk_vals = tk_vals.tolist()
+         if tgt_id in tk_idx:
+             j = tk_idx.index(tgt_id)
+             tk_idx.pop(j)
+             tk_vals.pop(j)
+         alt_ids = [tgt_id] + tk_idx[: topk - 1]
+         alt_ps = [conf] + tk_vals[: topk - 1]
+         alt_tokens = [tokenizer.decode([i], skip_special_tokens=True) for i in alt_ids]
+         entropy = float((-probs * (probs.clamp_min(1e-12).log())).sum().item())
+         gap12 = float(alt_ps[0] - (alt_ps[1] if len(alt_ps) > 1 else 0.0))
+         rel_prob = float((alt_ps[1] / alt_ps[0]) if (len(alt_ps) > 1 and alt_ps[0] > 0) else 0.0)
+         tokens_rows.append(
+             {
+                 "line_id": line_id,
+                 "seq_pos": step,
+                 "token_id": tgt_id,
+                 "token": alt_tokens[0],
+                 "confidence": conf,
+                 "rel_prob": rel_prob,
+                 "entropy": entropy,
+                 "gap12": gap12,
+                 "alt_tokens": "|".join(alt_tokens),
+                 "alt_probs": "|".join([f"{p:.6f}" for p in alt_ps]),
+             }
+         )
+         del probs
+     df = pd.DataFrame(
+         tokens_rows,
+         columns=[
+             "line_id",
+             "seq_pos",
+             "token_id",
+             "token",
+             "confidence",
+             "rel_prob",
+             "entropy",
+             "gap12",
+             "alt_tokens",
+             "alt_probs",
+         ],
+     )
+     return decoded_text, df
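+
+ # Illustrative reading of one row: confidence 0.55 with rel_prob 0.80 means
+ # the runner-up token had probability ~0.44, i.e. the model was torn between
+ # two candidates, so the token gets flagged (0.80 > REL_PROB_TH).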
+
+
+ # ----------------------------------------------------------------------
+ # Text splitting into aksharas (syllable units) for highlighting
+ #
+ # The following regex and helper functions split a Devanagari string into
+ # aksharas. This is necessary to map model tokens back to spans of
+ # characters when highlighting uncertain predictions.
+
+ DEV_CONS = "\u0915-\u0939\u0958-\u095F\u0978-\u097F"  # consonants, incl. nukta-variant ranges
+ INDEP_VOW = "\u0904-\u0914"  # independent vowels
+ NUKTA = "\u093C"  # nukta
+ VIRAMA = "\u094D"  # halant/virama
+ MATRAS = "\u093A-\u094C"  # dependent vowel signs
+ BINDUS = "\u0901\u0902\u0903"  # chandrabindu, anusvara, visarga
+ AKSHARA_RE = re.compile(
+     rf"(?:"
+     rf"(?:[{DEV_CONS}]{NUKTA}?)(?:{VIRAMA}(?:[{DEV_CONS}]{NUKTA}?))*"  # consonant cluster
+     rf"(?:[{MATRAS}])?"  # optional matra
+     rf"(?:[{BINDUS}])?"  # optional bindu/visarga
+     rf"|"
+     rf"(?:[{INDEP_VOW}](?:[{BINDUS}])?)"  # independent vowel (+ bindu)
+     rf")",
+     flags=re.UNICODE,
+ )
+
+
+ def split_aksharas(s: str):
+     """Split a string into Devanagari aksharas and return the pieces with their spans."""
+     spans = []
+     i = 0
+     while i < len(s):
+         m = AKSHARA_RE.match(s, i)
+         if m and m.end() > i:
+             spans.append((m.start(), m.end()))
+             i = m.end()
+         else:
+             spans.append((i, i + 1))
+             i += 1
+     return [s[a:b] for (a, b) in spans], spans
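+
+ # Example (illustrative): split_aksharas("नमस्ते") returns
+ # (["न", "म", "स्ते"], [(0, 1), (1, 2), (2, 6)]); the conjunct "स्ते" stays
+ # a single unit, so tokens can be mapped back to whole-akshara spans.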
+
+
+ def parse_alt_probs(s: str):
+     """Parse a pipe-separated probability string back into a list of floats."""
+     try:
+         return [float(x) for x in (s or "").split("|") if x != ""]
+     except Exception:
+         return []
+
+
+ def parse_alt_tokens(s: str):
+     """Parse a pipe-separated alternative-token string back into a list."""
+     return [(t if t is not None else "") for t in (s or "").split("|")]
+
+
+ def highlight_tokens_with_tooltips(
+     line_text: str, df_tok: pd.DataFrame, red_threshold: float, metric_column: str
+ ) -> str:
+     """Insert HTML spans around tokens whose chosen metric exceeds the threshold.
+
+     The metric column can be "rel_prob" (relative probability) or
+     "entropy". Tokens with a value strictly greater than red_threshold
+     are wrapped in a span with a tooltip listing alternative
+     predictions and their probabilities.
+
+     Parameters
+     ----------
+     line_text : str
+         The cleaned line prediction.
+     df_tok : pandas.DataFrame
+         DataFrame of token statistics for the corresponding line.
+     red_threshold : float
+         Values above this threshold are highlighted.
+     metric_column : str
+         Column name in df_tok used for thresholding.
+
+     Returns
+     -------
+     str
+         An HTML string with <span> elements inserted.
+     """
+     aks, spans = split_aksharas(line_text)
+     joined = "".join(aks)
+     used_ranges = []
+     insertions = []
+     for _, row in df_tok.iterrows():
+         token = row.get("token", "").strip()
+         try:
+             val = float(row.get(metric_column, 0))
+         except Exception:
+             continue
+         if val <= red_threshold or not token:
+             continue
+         # Try finding the token in the joined akshara sequence
+         start_char_idx = joined.find(token)
+         if start_char_idx == -1:
+             continue
+         # Locate the matching akshara span
+         ak_start = ak_end = None
+         cum_len = 0
+         for i, ak in enumerate(aks):
+             next_len = cum_len + len(ak)
+             if cum_len <= start_char_idx < next_len:
+                 ak_start = i
+             if cum_len < start_char_idx + len(token) <= next_len:
+                 ak_end = i + 1
+                 break
+             cum_len = next_len
+         if ak_start is None or ak_end is None:
+             continue
+         # Avoid overlaps
+         if any(r[0] < ak_end and ak_start < r[1] for r in used_ranges):
+             continue
+         used_ranges.append((ak_start, ak_end))
+         # Character positions
+         char_start = spans[ak_start][0]
+         char_end = spans[ak_end - 1][1]
+         # Build tooltip content
+         alt_toks = row.get("alt_tokens", "").split("|")
+         alt_probs = row.get("alt_probs", "").split("|")
+         tooltip_lines = []
+         for t, p in zip(alt_toks, alt_probs):
+             try:
+                 prob = float(p)
+             except Exception:
+                 prob = 0.0
+             tooltip_lines.append(f"{_html_escape(t)}: {prob:.3f}")
+         tooltip = "\n".join(tooltip_lines)
+         token_str = _html_escape(line_text[char_start:char_end])
+         html_token = f"<span class='ocr-token' data-tooltip='{_html_escape(tooltip)}'>{token_str}</span>"
+         insertions.append((char_start, char_end, html_token))
+     if not insertions:
+         return _html_escape(line_text)
+     insertions.sort()
+     out_parts = []
+     last_idx = 0
+     for s, e, html_tok in insertions:
+         out_parts.append(_html_escape(line_text[last_idx:s]))
+         out_parts.append(html_tok)
+         last_idx = e
+     out_parts.append(_html_escape(line_text[last_idx:]))
+     return "".join(out_parts)
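+
+ # Illustrative output shape: a flagged token comes back wrapped as
+ # <span class='ocr-token' data-tooltip='...alternatives...'>token</span>,
+ # while the surrounding text is HTML-escaped but otherwise left untouched.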
+
+
+ def _html_escape(s: str) -> str:
+     return (
+         s.replace("&", "&amp;")
+         .replace("<", "&lt;")
+         .replace(">", "&gt;")
+         .replace("\"", "&quot;")
+         .replace("'", "&#x27;")
+     )
+
+
+ # ----------------------------------------------------------------------
+ # Main OCR wrapper for Gradio
+ #
+ def run_ocr(
+     image: np.ndarray | None,
+     xml_file,
+     apply_gray: bool,
+     apply_bin: bool,
+     highlight_metric: str,
+ ):
+     """Run the OCR pipeline on user inputs and return results for Gradio.
+
+     Parameters
+     ----------
+     image : numpy.ndarray or None
+         The uploaded image converted to a NumPy array by Gradio. If
+         None, the function returns empty results.
+     xml_file : bytes, str, file-like, dict or None
+         The uploaded XML file. With gr.File(type="binary") this is raw
+         bytes, but file paths, file objects and Gradio dicts are also
+         handled below. If None, no segmentation is applied and the
+         entire image is processed as a single line.
+     apply_gray : bool
+         Whether to convert the image to grayscale before OCR.
+     apply_bin : bool
+         Whether to apply binarization (Otsu threshold) before OCR. If
+         selected, grayscale conversion is applied first automatically.
+     highlight_metric : str
+         Which metric to use for highlighting ("Relative Probability" or
+         "Entropy").
+
+     Returns
+     -------
+     overlay_img : PIL.Image or None
+         Image with segmentation boxes drawn. None if no input image.
+     predictions_html : str
+         HTML-formatted predicted text with highlighted tokens.
+     df_scores : pandas.DataFrame or None
+         DataFrame of per-token statistics. None if no input image.
+     txt_file_path : str or None
+         Path to a temporary .txt file containing the plain predicted text.
+     csv_file_path : str or None
+         Path to a temporary CSV file containing the extended token scores.
+     """
+     if image is None:
+         return None, "", None, None, None
+     # Convert the numpy array to a PIL image
+     pil_img = Image.fromarray(image).convert("RGB")
+     # Apply preprocessing as requested
+     if apply_gray:
+         pil_img = pil_img.convert("L").convert("RGB")
+     if apply_bin:
+         img_cv = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
+         _, bin_img = cv2.threshold(img_cv, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+         pil_img = Image.fromarray(bin_img).convert("RGB")
+     # Parse segmentation boxes if XML provided
+     boxes: list = []
+     if xml_file:
+         # Determine the correct way to extract bytes from the uploaded file.
+         xml_bytes = None
+         # If gr.File is configured with type="binary", xml_file is raw bytes.
+         if isinstance(xml_file, (bytes, bytearray)):
+             xml_bytes = bytes(xml_file)
+         # When type="filepath", xml_file is a str path.
+         elif isinstance(xml_file, str):
+             try:
+                 with open(xml_file, "rb") as f:
+                     xml_bytes = f.read()
+             except Exception:
+                 xml_bytes = None
+         # If a temporary file object is passed in, read its contents.
+         elif hasattr(xml_file, "read"):
+             try:
+                 xml_bytes = xml_file.read()
+             except Exception:
+                 xml_bytes = None
+         # If xml_file is a dictionary from Gradio (not expected with type="binary"),
+         # attempt to extract the data key.
+         elif isinstance(xml_file, dict) and "data" in xml_file:
+             xml_bytes = xml_file.get("data")
+         if xml_bytes:
+             try:
+                 boxes = parse_boxes_from_xml(xml_bytes, level="line", image_size=pil_img.size)
+                 boxes = sort_boxes_reading_order(boxes)[:MAX_LINES]
+             except Exception:
+                 boxes = []
+     # Run OCR for each segmented line or the whole image
+     dfs = []
+     concatenated_parts = []
+     line_text_by_id = {}
+     if boxes:
+         pad = 2
+         for idx, b in enumerate(boxes, 1):
+             # Create a tight crop around the line
+             if "points" in b:
+                 pts = b["points"]
+                 mask = Image.new("L", pil_img.size, 0)
+                 ImageDraw.Draw(mask).polygon(pts, outline=1, fill=255)
+                 seg_img = Image.new("RGB", pil_img.size, (255, 255, 255))
+                 seg_img.paste(pil_img, mask=mask)
+                 xs = [x for x, y in pts]
+                 ys = [y for x, y in pts]
+                 x1 = max(0, int(min(xs) - pad))
+                 y1 = max(0, int(min(ys) - pad))
+                 x2 = min(pil_img.width, int(max(xs) + pad))
+                 y2 = min(pil_img.height, int(max(ys) + pad))
+                 crop = seg_img.crop((x1, y1, x2, y2))
+             else:
+                 x1, y1, x2, y2 = b["bbox"]
+                 x1p = max(0, x1 - pad)
+                 y1p = max(0, y1 - pad)
+                 x2p = min(pil_img.width, x2 + pad)
+                 y2p = min(pil_img.height, y2 + pad)
+                 crop = pil_img.crop((x1p, y1p, x2p, y2p))
+             # Run inference on the crop
+             seg_text, df_tok = predict_and_score_once(crop, line_id=idx, topk=TOPK)
+             seg_text = clean_text(seg_text)
+             # Choose metric
+             if highlight_metric == "Relative Probability":
+                 red_threshold = REL_PROB_TH
+                 metric_col = "rel_prob"
+             else:
+                 red_threshold = 0.10  # heuristic threshold for entropy
+                 metric_col = "entropy"
+             # Highlight uncertain tokens
+             seg_text_flagged = highlight_tokens_with_tooltips(seg_text, df_tok, red_threshold, metric_col)
+             concatenated_parts.append(seg_text_flagged)
+             df_tok["line_id"] = idx
+             dfs.append(df_tok)
+             line_text_by_id[idx] = seg_text_flagged
+         predicted_html = "<br>".join(concatenated_parts).strip()
+         df_all = pd.concat(dfs, ignore_index=True)
+     else:
+         # Single pass on the whole image
+         seg_text, df_all = predict_and_score_once(pil_img, line_id=1, topk=TOPK)
+         seg_text = clean_text(seg_text)
+         if highlight_metric == "Relative Probability":
+             red_threshold = REL_PROB_TH
+             metric_col = "rel_prob"
+         else:
+             red_threshold = 0.10
+             metric_col = "entropy"
+         seg_text_flagged = highlight_tokens_with_tooltips(seg_text, df_all, red_threshold, metric_col)
+         predicted_html = seg_text_flagged
+         line_text_by_id[1] = seg_text_flagged
+     # Draw overlay image
+     overlay_img = draw_boxes(pil_img, boxes) if boxes else pil_img
+     # Create downloads
+     df_all = df_all.copy()
+     # Drop the last empty token per line to tidy up the output
+     df_all.sort_values(["line_id", "seq_pos"], inplace=True)
+     to_drop = []
+     for line_id, group in df_all.groupby("line_id"):
+         if group.iloc[-1]["token"].strip() == "":
+             to_drop.append(group.index[-1])
+     df_all = df_all.drop(index=to_drop)
+     # Prepare plain text by stripping HTML tags and replacing <br>
+     plain_text = re.sub(r"<[^>]*>", "", predicted_html.replace("<br>", "\n"))
+     # Write temporary files for download
+     txt_path = None
+     csv_path = None
+     try:
+         txt_fd = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w", encoding="utf-8")
+         txt_fd.write(plain_text)
+         txt_fd.flush()
+         txt_path = txt_fd.name
+         txt_fd.close()
+     except Exception:
+         txt_path = None
+     try:
+         csv_fd = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode="w", encoding="utf-8")
+         df_all.to_csv(csv_fd, index=False)
+         csv_fd.flush()
+         csv_path = csv_fd.name
+         csv_fd.close()
+     except Exception:
+         csv_path = None
+     return overlay_img, predicted_html, df_all, txt_path, csv_path
+
+
+ # ----------------------------------------------------------------------
+ # Build Gradio Interface
+ #
+ def create_gradio_interface():
+     """Create and return the Gradio Blocks interface."""
+     with gr.Blocks(title="Old Nepali HTR") as demo:
+         gr.Markdown(
+             """# Old Nepali HTR (Gradio)
+
+ Upload a scanned image and (optionally) a segmentation XML file. Choose preprocessing
+ steps and a highlight metric, then click **Run OCR** to extract the text.
+ Uncertain tokens are highlighted with tooltips showing alternative predictions.
+ You can edit the plain text below and download it or the full token scores."""
+         )
+         with gr.Row():
+             image_input = gr.Image(type="numpy", label="Upload Image")
+             # When used as an input, gr.File returns either a file path or bytes
+             # depending on the `type` parameter. Setting type="binary" ensures
+             # that the XML content is passed directly as bytes to the callback,
+             # avoiding the need to reopen a temporary file.
+             xml_input = gr.File(
+                 label="Upload segmentation XML (optional)",
+                 file_count="single",
+                 type="binary",
+                 file_types=[".xml"],
+             )
+         with gr.Row():
+             apply_gray_checkbox = gr.Checkbox(label="Convert to Grayscale", value=False)
+             apply_bin_checkbox = gr.Checkbox(label="Binarize", value=False)
+             metric_radio = gr.Radio(
+                 ["Relative Probability", "Entropy"],
+                 label="Highlight tokens by",
+                 value="Relative Probability",
+             )
+         run_btn = gr.Button("Run OCR")
+         # Outputs
+         overlay_output = gr.Image(label="Detected Regions")
+         predictions_output = gr.HTML(label="Predictions (HTML)")
+         df_output = gr.DataFrame(label="Token Scores", interactive=False)
+         # Separate file outputs for the OCR prediction, token scores and edited text.
+         ocr_txt_output = gr.File(label="Download OCR Prediction (.txt)")
+         ocr_csv_output = gr.File(label="Download Token Scores (.csv)")
+         edited_txt_output = gr.File(label="Download edited text (.txt)")
+
+         # Editable text area
+         edited_text = gr.Textbox(
+             label="Edit full predicted text", lines=8, interactive=True
+         )
+         download_edited_btn = gr.Button("Download edited text")
+
+         # Callback for OCR
+         def on_run(image, xml, gray, binarize, metric):
+             return run_ocr(image, xml, gray, binarize, metric)
+
+         run_btn.click(
+             fn=on_run,
+             inputs=[image_input, xml_input, apply_gray_checkbox, apply_bin_checkbox, metric_radio],
+             outputs=[overlay_output, predictions_output, df_output, ocr_txt_output, ocr_csv_output],
+         )
+
+         # Populate the editable textbox with the plain text of the predictions
+         def update_edited_text(pred_html):
+             plain = re.sub(r"<[^>]*>", "", (pred_html or "").replace("<br>", "\n"))
+             return plain
+
+         predictions_output.change(
+             fn=update_edited_text,
+             inputs=predictions_output,
+             outputs=edited_text,
+         )
+
+         # Download edited text by writing it to a temporary file
+         def download_edited(txt):
+             if not txt:
+                 return None
+             try:
+                 fd = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w", encoding="utf-8")
+                 fd.write(txt)
+                 fd.flush()
+                 path = fd.name
+                 fd.close()
+                 return path
+             except Exception:
+                 return None
+
+         download_edited_btn.click(
+             fn=download_edited,
+             inputs=edited_text,
+             outputs=edited_txt_output,
+         )
+     return demo
+
+
+ iface = create_gradio_interface()
+ iface.launch()