AnjaliSarawgi committed
Commit 6c8fcdd
Parent(s): 8cb1f4a

Add application file

Files changed (1): app.py (+304 -27)

app.py CHANGED
@@ -1,3 +1,27 @@
+"""
+Gradio application for performing OCR on scanned Old Nepali documents.
+
+This script is a Gradio port of a Streamlit application originally built
+to visualize and edit OCR output. It loads a pre-trained model for
+sequence decoding, accepts an input image (and an optional segmentation
+XML file in ALTO format), performs OCR on the segmented lines, highlights
+low-confidence tokens, and offers downloads of both the raw text and the
+per-token scores.
+
+The heavy-lifting functions (model loading, preprocessing, inference
+and highlighting) are adapted directly from the Streamlit version. The
+UI has been simplified for Gradio: users upload an image and an optional
+XML file, choose preprocessing steps and a highlight metric, then run
+OCR. The results are displayed alongside the overlaid segmentation
+boxes and a table of token scores. An editable textbox lets users
+modify the predicted text before downloading it.
+
+To run this app locally, install gradio (`pip install gradio`) and
+execute this script with Python:
+
+    python app.py
+
+"""
 
 import io
 import os
@@ -22,18 +46,68 @@ from transformers import (
 from matplotlib import cm
 import gradio as gr
 
+# ----------------------------------------------------------------------
+# Configuration
+#
+# These constants control various aspects of the OCR pipeline. You can
+# adjust them to trade off accuracy, performance and output volume.
+
+# The maximum number of tokens to decode for a single line. If your
+# documents typically have longer lines you can increase this value, but
+# beware that very long sequences use more memory.
 MAX_LEN: int = 128
+
+# How many alternative tokens to keep when computing per-token statistics.
 TOPK: int = 3
+
+# If an XML segmentation file is provided, only process the first
+# MAX_LINES lines. This prevents huge documents from consuming
+# excessive resources.
 MAX_LINES: int = 120
+
+# Images are resized so that the longest side does not exceed this
+# number of pixels before they are passed to the OCR model. Increasing
+# this value may improve accuracy at the cost of speed and memory.
 RESIZE_MAX_SIDE: int = 800
+
+# Threshold used when highlighting tokens by relative probability. A
+# Top2/Top1 ratio greater than this value causes the token to be
+# highlighted in red.
 REL_PROB_TH: float = 0.70
+
+# A regex used to strip Unicode control characters before text
+# normalization. Soft hyphens, zero-width spaces and similar marks
+# interfere with accurate token matching.
 CLEANUP: re.Pattern = re.compile(r"[\u00AD\u200B\u200C\u200D]")
+
+# Default font path for rendering predictions directly on the image.
 FONT_PATH: str = os.path.join("NotoSansDevanagari-Regular.ttf")
 
 
+# ----------------------------------------------------------------------
+# Model loading
+#
+# Loading the model and the associated tokenizer/processor is slow. Use
+# functools.lru_cache to ensure this only happens once per process.
+
 @lru_cache(maxsize=1)
 def load_model():
+    """Load the OCR model, tokenizer and feature extractor.
+
+    Returns
+    -------
+    model : VisionEncoderDecoderModel
+        The loaded model in evaluation mode.
+    tokenizer : PreTrainedTokenizerFast
+        Tokenizer corresponding to the decoder part of the model.
+    feature_extractor : callable
+        Feature extractor converting PIL images into model inputs.
+    device : torch.device
+        The device (CPU or CUDA) used for inference.
+    """
     model_path = "AnjaliSarawgi/model-oct"
+    # In an offline environment the HF token is None; to use a private
+    # model, set HF_TOKEN in your environment.
     hf_token = os.environ.get("HF_TOKEN")
     model = VisionEncoderDecoderModel.from_pretrained(model_path, token=hf_token)
     tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path, token=hf_token)
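
Since load_model is wrapped in lru_cache(maxsize=1), every call after the first returns the same cached objects. A hypothetical call site illustrating the caching:

    model, tokenizer, feature_extractor, device = load_model()  # first call loads weights
    model_again, *_ = load_model()                               # subsequent calls are instant
    assert model is model_again                                  # same cached instance
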
@@ -43,13 +117,45 @@ def load_model():
     return model, tokenizer, processor.feature_extractor, device
 
 
+# ----------------------------------------------------------------------
+# Utility functions
+#
+
 def clean_text(text: str) -> str:
+    """Normalize a decoded string and strip all whitespace.
+
+    Parameters
+    ----------
+    text : str
+        The raw decoded string from the model.
+
+    Returns
+    -------
+    str
+        The string after Unicode NFC normalization, with control
+        characters and all whitespace removed. Whitespace can be
+        discarded because the predictions are later tokenized at the
+        akshara (syllable) level.
+    """
     text = unicodedata.normalize("NFC", text)
     text = CLEANUP.sub("", text)
     return re.sub(r"\s+", "", text)
 
 
 def prepare_image(image: Image.Image, max_side: int = RESIZE_MAX_SIDE) -> Image.Image:
+    """Downscale the image so that its longest side is at most max_side.
+
+    Parameters
+    ----------
+    image : PIL.Image
+        Input image.
+    max_side : int, optional
+        Maximum allowed length for the longest side of the image.
+
+    Returns
+    -------
+    PIL.Image
+        The RGB image, downscaled only if its longest side exceeds
+        max_side.
+    """
     img = image.convert("RGB")
     w, h = img.size
     if max(w, h) > max_side:
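
A quick sanity check of the two helpers above (hypothetical snippet; the sample string and sizes are made up):

    from PIL import Image

    sample = "क\u200bख  ग\u00ad"          # zero-width space, spaces, soft hyphen
    assert clean_text(sample) == "कखग"    # control characters and whitespace removed

    img = Image.new("RGB", (1600, 1200))
    resized = prepare_image(img)          # with RESIZE_MAX_SIDE = 800
    assert max(resized.size) <= 800       # longest side capped
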
@@ -58,10 +164,36 @@ def prepare_image(image: Image.Image, max_side: int = RESIZE_MAX_SIDE) -> Image.Image:
 
 
 def get_amp_ctx():
+    """Return the appropriate context manager for automatic mixed precision."""
     return torch.cuda.amp.autocast if torch.cuda.is_available() else contextlib.nullcontext
 
+
+# ----------------------------------------------------------------------
+# XML parsing and segmentation
 #
 def parse_boxes_from_xml(xml_bytes: bytes, level: str = "line", image_size: tuple | None = None):
+    """Parse ALTO or PAGE XML and extract bounding boxes.
+
+    Parameters
+    ----------
+    xml_bytes : bytes
+        Raw XML bytes.
+    level : {"block", "line", "word"}, optional
+        The segmentation level to extract. For OCR we use "line".
+    image_size : tuple or None
+        If provided, image_size=(width, height) allows rescaling of the
+        coordinates to match the actual image. ALTO files often store
+        absolute page sizes that differ from the image dimensions.
+
+    Returns
+    -------
+    list of dict
+        Each dict represents a bounding box with keys:
+        - "bbox": [x1, y1, x2, y2]
+        - "points": list of (x, y) if polygonal coordinates exist
+        - "id": line identifier (string)
+        - "label": the type of element (e.g. TextLine)
+    """
     def _strip_ns(elem):
         for e in elem.iter():
             if isinstance(e.tag, str) and e.tag.startswith("{"):
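
Note that get_amp_ctx returns the context-manager factory itself (torch.cuda.amp.autocast or contextlib.nullcontext), not an instance, so callers must invoke it. A minimal usage sketch at a hypothetical inference site:

    amp = get_amp_ctx()
    with torch.no_grad(), amp():
        generated = model.generate(pixel_values)  # hypothetical call; autocast on GPU, no-op on CPU
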
@@ -191,6 +323,7 @@ def parse_boxes_from_xml(xml_bytes: bytes, level: str = "line", image_size: tuple | None = None):
 
 
 def sort_boxes_reading_order(boxes, y_tol: int = 10):
+    """Sort bounding boxes top-to-bottom, then left-to-right."""
     def key(b):
         x1, y1, x2, y2 = b["bbox"]
         return (round(y1 / max(1, y_tol)), y1, x1)
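
The y_tol banding in the sort key groups boxes whose top edges fall in the same band into one visual row; within a band, exact y and then x break ties. A small worked example (hypothetical boxes, assuming the function returns sorted(boxes, key=key)):

    boxes = [
        {"bbox": [400, 12, 600, 40]},  # right box of the first row, key (1, 12, 400)
        {"bbox": [10, 8, 200, 36]},    # left box of the first row,  key (1, 8, 10)
        {"bbox": [10, 60, 200, 90]},   # second row,                 key (6, 60, 10)
    ]
    ordered = sort_boxes_reading_order(boxes, y_tol=10)
    assert [b["bbox"][0] for b in ordered] == [10, 400, 10]
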
@@ -198,6 +331,21 @@ def sort_boxes_reading_order(boxes, y_tol: int = 10):
 
 
 def draw_boxes(img: Image.Image, boxes):
+    """Overlay semi-transparent red polygons or rectangles on an image.
+
+    Parameters
+    ----------
+    img : PIL.Image
+        The base image.
+    boxes : list of dict
+        Segmentation boxes with either 'points' or 'bbox' keys.
+
+    Returns
+    -------
+    PIL.Image
+        An image with red overlays marking each box. Boxes are numbered
+        starting from 1.
+    """
     base = img.convert("RGBA")
     overlay = Image.new("RGBA", base.size, (0, 0, 0, 0))
     draw = ImageDraw.Draw(overlay)
@@ -218,7 +366,33 @@ def draw_boxes(img: Image.Image, boxes):
     return Image.alpha_composite(base, overlay).convert("RGB")
 
 
+# ----------------------------------------------------------------------
+# OCR inference per line
+#
 def predict_and_score_once(image: Image.Image, line_id: int = 1, topk: int = TOPK):
+    """Run the model on a single cropped line and return predictions and scores.
+
+    This helper wraps the model.generate call to obtain per-token
+    probabilities and derives a DataFrame summarizing each decoding step.
+
+    Parameters
+    ----------
+    image : PIL.Image
+        Cropped segment to process.
+    line_id : int, optional
+        Identifier used in the output DataFrame.
+    topk : int, optional
+        Number of alternative tokens to keep for each decoding position.
+
+    Returns
+    -------
+    decoded_text : str
+        Cleaned predicted string for the line.
+    df : pandas.DataFrame
+        Table with one row per generated token containing the columns:
+        line_id, seq_pos, token_id, token, confidence, rel_prob,
+        entropy, gap12, alt_tokens, alt_probs.
+    """
     model, tokenizer, feature_extractor, device = load_model()
     img = prepare_image(image)
     pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values.to(device)
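
The body of predict_and_score_once lies mostly outside this hunk. The statistics named in the docstring can be derived from generate's step scores along these lines (a minimal sketch, not the author's exact implementation):

    import torch.nn.functional as F

    out = model.generate(
        pixel_values,
        max_length=MAX_LEN,
        output_scores=True,
        return_dict_in_generate=True,
    )
    for step_logits in out.scores:          # one logits tensor per generated token
        probs = F.softmax(step_logits[0], dim=-1)
        top = torch.topk(probs, k=TOPK)
        confidence = top.values[0].item()                   # Top1 probability
        rel_prob = (top.values[1] / top.values[0]).item()   # Top2/Top1 ratio
        gap12 = (top.values[0] - top.values[1]).item()      # Top1 - Top2 gap
        entropy = -(probs * probs.clamp_min(1e-12).log()).sum().item()
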
@@ -305,6 +479,13 @@ def predict_and_score_once(image: Image.Image, line_id: int = 1, topk: int = TOPK):
     return decoded_text, df
 
 
+# ----------------------------------------------------------------------
+# Text splitting into aksharas (syllable units) for highlighting
+#
+# The following regex ranges and helper functions split a Devanagari
+# string into aksharas. This is necessary to map model tokens back to
+# character spans when highlighting uncertain predictions.
+
 DEV_CONS = "\u0915-\u0939\u0958-\u095F\u0978-\u097F"  # consonants incl. nukta variant ranges
 INDEP_VOW = "\u0904-\u0914"  # independent vowels
 NUKTA = "\u093C"  # nukta
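
split_aksharas itself is not shown in this diff; a simplified sketch of how ranges like these combine into a cluster regex (an approximation for illustration, not the app's exact rule; MATRA, VIRAMA and SIGNS are assumed ranges):

    MATRA = "\u093E-\u094C\u0962\u0963"  # dependent vowel signs (approximate)
    VIRAMA = "\u094D"
    SIGNS = "\u0900-\u0903"              # chandrabindu, anusvara, visarga

    # One akshara: a virama-joined consonant cluster with optional nukta,
    # matra and signs, or an independent vowel with signs, or any one char.
    AKSHARA = re.compile(
        f"(?:[{DEV_CONS}]{NUKTA}?(?:{VIRAMA}[{DEV_CONS}]{NUKTA}?)*[{MATRA}]?[{SIGNS}]*"
        f"|[{INDEP_VOW}][{SIGNS}]*"
        f"|.)"
    )

    def split_aksharas_sketch(text: str):
        matches = list(AKSHARA.finditer(text))
        return [m.group(0) for m in matches], [(m.start(), m.end()) for m in matches]
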
@@ -352,23 +533,65 @@ def parse_alt_tokens(s: str):
 def highlight_tokens_with_tooltips(
     line_text: str, df_tok: pd.DataFrame, red_threshold: float, metric_column: str
 ) -> str:
+    """Wrap tokens in HTML spans coloured by the chosen uncertainty metric.
+
+    The metric column can be "rel_prob" (relative probability) or
+    "entropy". Every token that can be located in the line is wrapped
+    in a span whose colour class reflects the metric value, with a
+    tooltip listing alternative predictions and their probabilities.
+
+    Parameters
+    ----------
+    line_text : str
+        The cleaned line prediction.
+    df_tok : pandas.DataFrame
+        DataFrame of token statistics for the corresponding line.
+    red_threshold : float
+        No longer used directly; colour classes come from the fixed
+        thresholds in color_class below.
+    metric_column : str
+        Column name in df_tok used for the colour classification.
+
+    Returns
+    -------
+    str
+        An HTML string with <span> elements inserted.
+    """
     aks, spans = split_aksharas(line_text)
     joined = "".join(aks)
-    used_ranges = []
-    insertions = []
+    used_ranges: list = []
+    insertions: list = []
+
+    # Colour classification depending on the metric
+    def color_class(val: float) -> str:
+        if metric_column == "rel_prob":
+            # Same thresholds as the original app: >=0.70 red, >=0.05 yellow, otherwise green
+            if val >= 0.70:
+                return "token-red"
+            elif val >= 0.05:
+                return "token-yellow"
+            else:
+                return "token-green"
+        else:
+            # For entropy, high values indicate uncertainty. These thresholds are heuristics.
+            if val >= 2.0:
+                return "token-red"
+            elif val >= 1.0:
+                return "token-yellow"
+            else:
+                return "token-green"
+
     for _, row in df_tok.iterrows():
-        token = row.get("token", "").strip()
+        token = str(row.get("token", "")).strip()
+        if not token:
+            continue
+        # Extract the metric value for classification
         try:
             val = float(row.get(metric_column, 0))
         except Exception:
-            continue
-        if val <= red_threshold or not token:
-            continue
-        # Try finding the token in the joined akshara sequence
+            val = 0.0
+        # Find the first occurrence of the token in the joined akshara sequence
         start_char_idx = joined.find(token)
         if start_char_idx == -1:
            continue
-        # Locate matching akshara span
+        # Locate the corresponding akshara boundaries
         ak_start = ak_end = None
         cum_len = 0
         for i, ak in enumerate(aks):
@@ -381,17 +604,16 @@ def highlight_tokens_with_tooltips(
             cum_len = next_len
         if ak_start is None or ak_end is None:
             continue
-        # Avoid overlaps
+        # Prevent overlapping spans
         if any(r[0] < ak_end and ak_start < r[1] for r in used_ranges):
             continue
         used_ranges.append((ak_start, ak_end))
-        # Character positions
         char_start = spans[ak_start][0]
         char_end = spans[ak_end - 1][1]
-        # Build tooltip content
+        # Prepare tooltip content
         alt_toks = row.get("alt_tokens", "").split("|")
         alt_probs = row.get("alt_probs", "").split("|")
-        tooltip_lines = []
+        tooltip_lines: list = []
         for t, p in zip(alt_toks, alt_probs):
             try:
                 prob = float(p)
@@ -400,12 +622,16 @@ def highlight_tokens_with_tooltips(
                 tooltip_lines.append(f"{_html_escape(t)}: {prob:.3f}")
         tooltip = "\n".join(tooltip_lines)
         token_str = _html_escape(line_text[char_start:char_end])
-        html_token = f"<span class='ocr-token' data-tooltip='{_html_escape(tooltip)}'>{token_str}</span>"
+        cls = color_class(val)
+        html_token = (
+            f"<span class='ocr-token {cls}' data-tooltip='{_html_escape(tooltip)}'>{token_str}</span>"
+        )
         insertions.append((char_start, char_end, html_token))
+    # If nothing was highlighted, return the escaped original text
     if not insertions:
         return _html_escape(line_text)
     insertions.sort()
-    out_parts = []
+    out_parts: list = []
     last_idx = 0
     for s, e, html_tok in insertions:
         out_parts.append(_html_escape(line_text[last_idx:s]))
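
A usage sketch for the rewritten highlighter (hypothetical one-row DataFrame; split_aksharas and _html_escape are defined elsewhere in app.py):

    df_tok = pd.DataFrame([{
        "token": "नेपाल",
        "rel_prob": 0.82,  # >= 0.70, so color_class picks token-red
        "alt_tokens": "नेपाल|नपाल|नेपल",
        "alt_probs": "0.41|0.34|0.12",
    }])
    html = highlight_tokens_with_tooltips("नेपाल", df_tok, REL_PROB_TH, "rel_prob")
    # html wraps the line in <span class='ocr-token token-red' data-tooltip='...'>नेपाल</span>
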
@@ -435,7 +661,40 @@ def run_ocr(
     apply_bin: bool,
     highlight_metric: str,
 ):
-
+    """Run the OCR pipeline on user inputs and return results for Gradio.
+
+    Parameters
+    ----------
+    image : numpy.ndarray or None
+        The uploaded image converted to a NumPy array by Gradio. If
+        None, the function returns empty results.
+    xml_file : tuple or None
+        The uploaded XML file as provided by gr.File: the first element
+        is the file name and the second is the file's bytes. If None,
+        no segmentation is applied and the entire image is processed as
+        a single line.
+    apply_gray : bool
+        Whether to convert the image to grayscale before OCR.
+    apply_bin : bool
+        Whether to apply binarization (Otsu threshold) before OCR. If
+        selected, grayscale conversion is applied first automatically.
+    highlight_metric : str
+        Which metric to use for highlighting ("Relative Probability" or
+        "Entropy").
+
+    Returns
+    -------
+    overlay_img : PIL.Image or None
+        Image with segmentation boxes drawn. None if no input image.
+    predictions_html : str
+        HTML-formatted predicted text with highlighted tokens.
+    df_scores : pandas.DataFrame or None
+        DataFrame of per-token statistics. None if no input image.
+    txt_file_path : str or None
+        Path to a temporary .txt file containing the plain predicted text.
+    csv_file_path : str or None
+        Path to a temporary .csv file containing the extended token scores.
+    """
     if image is None:
         return None, "", None, None, None
     # Convert the numpy array to a PIL image
@@ -548,6 +807,32 @@ def run_ocr(
         if group.iloc[-1]["token"].strip() == "":
             to_drop.append(group.index[-1])
     df_all = df_all.drop(index=to_drop)
+    # Inject style definitions for token colouring and tooltips
+    style_block = """
+    <style>
+    .ocr-token { position: relative; cursor: pointer; padding: 0 2px; border-radius: 2px; }
+    .ocr-token.token-red { background-color: rgba(255, 107, 107, 0.7); }
+    .ocr-token.token-yellow { background-color: rgba(255, 217, 59, 0.7); }
+    .ocr-token.token-green { background-color: rgba(107, 207, 99, 0.7); }
+    .ocr-token:hover::after {
+        content: attr(data-tooltip);
+        position: absolute;
+        bottom: 120%;
+        left: 0;
+        white-space: pre-line;
+        background: #333;
+        color: #fff;
+        padding: 6px 10px;
+        border-radius: 6px;
+        font-size: 12px;
+        z-index: 999;
+        max-width: 220px;
+        box-shadow: 0 4px 12px rgba(0,0,0,0.15);
+    }
+    </style>
+    """
+    if predicted_html:
+        predicted_html = style_block + predicted_html
     # Prepare plain text by stripping HTML tags and replacing <br>
     plain_text = re.sub(r"<[^>]*>", "", predicted_html.replace("<br>", "\n"))
     # Write temporary files
@@ -569,10 +854,7 @@ def run_ocr(
         csv_fd.close()
     except Exception:
         csv_path = None
-    # return overlay_img, predicted_html, df_all, txt_path, csv_path
-    txt_bytes = plain_text.encode("utf-8")
-    csv_bytes = df_all.to_csv(index=False).encode("utf-8")
-    return overlay_img, predicted_html, df_all, txt_bytes, csv_bytes
+    return overlay_img, predicted_html, df_all, txt_path, csv_path
 
 
 # ----------------------------------------------------------------------
@@ -581,7 +863,7 @@ def run_ocr(
 def create_gradio_interface():
     """Create and return the Gradio Blocks interface."""
     with gr.Blocks(title="Old Nepali HTR") as demo:
-        gr.Markdown("""# Old Nepali HTR \n\nUpload a scanned image and (optionally) a segmentation XML file. Choose preprocessing\nsteps and a highlight metric, then click **Run OCR** to extract the text.\nUncertain tokens are highlighted with tooltips showing alternative predictions.\nYou can edit the plain text below and download it or the full token scores.""")
+        gr.Markdown("""# Old Nepali HTR (Gradio)\n\nUpload a scanned image and (optionally) a segmentation XML file. Choose preprocessing\nsteps and a highlight metric, then click **Run OCR** to extract the text.\nUncertain tokens are highlighted with tooltips showing alternative predictions.\nYou can edit the plain text below and download it or the full token scores.""")
         with gr.Row():
             image_input = gr.Image(type="numpy", label="Upload Image")
             # When used as an input, gr.File returns either a file path or bytes
@@ -607,11 +889,8 @@ def create_gradio_interface():
         predictions_output = gr.HTML(label="Predictions (HTML)")
         df_output = gr.DataFrame(label="Token Scores", interactive=False)
         # Separate file outputs for the OCR prediction, token scores and edited text.
-        # ocr_txt_output = gr.File(label="Download OCR Prediction (.txt)")
-        # ocr_csv_output = gr.File(label="Download Token Scores (.csv)")
-        ocr_txt_output = gr.File(label="Download OCR Prediction (.txt)", type="binary")
-        ocr_csv_output = gr.File(label="Download Token Scores (.csv)", type="binary")
-
+        ocr_txt_output = gr.File(label="Download OCR Prediction (.txt)")
+        ocr_csv_output = gr.File(label="Download Token Scores (.csv)")
         edited_txt_output = gr.File(label="Download edited text (.txt)")
 
         # Editable text area
@@ -662,7 +941,5 @@ def create_gradio_interface():
     )
     return demo
 
-
-
 iface = create_gradio_interface()
 iface.launch()
 