AnjaliSarawgi committed
Commit 8cb1f4a · 1 Parent(s): 68afe16

Add application file

Files changed (1)
  1. app.py +11 -242
app.py CHANGED
@@ -1,27 +1,3 @@
- """
- Gradio application for performing OCR on scanned Old Nepali documents.
-
- This script is a Gradio port of a Streamlit application originally built
- to visualize and edit OCR output. It loads a pre‑trained model for
- sequence decoding, accepts an input image (and optional segmentation
- XML in ALTO format), performs OCR on segmented lines, highlights tokens
- with low confidence and offers downloads of both the raw text and per
- token scores.
-
- The heavy lifting functions (model loading, pre‑processing, inference
- and highlighting) are adapted directly from the Streamlit version. The
- UI has been simplified for Gradio: users upload an image and optional
- XML file, choose preprocessing steps and a highlight metric, then run
- OCR. The results are displayed alongside the overlaid segmentation
- boxes and a table of token scores. An editable textbox lets users
- modify the predicted text before downloading it.
-
- To run this app locally, install gradio (`pip install gradio`) and
- execute this script with Python:
-
- python gradio_app.py
-
- """
 
  import io
  import os
@@ -46,68 +22,18 @@ from transformers import (
  from matplotlib import cm
  import gradio as gr
 
- # ----------------------------------------------------------------------
- # Configuration
- #
- # These constants control various aspects of the OCR pipeline. You can
- # adjust them to trade off accuracy, performance or output volume.
-
- # The maximum number of tokens to decode for a single line. If your
- # documents typically have longer lines you can increase this value, but
- # beware that very long sequences may cause more memory usage.
  MAX_LEN: int = 128
-
- # How many alternative tokens to keep when computing per–token statistics.
  TOPK: int = 3
-
- # If an XML segmentation file is provided, only process the first
- # MAX_LINES lines. This prevents huge documents from consuming
- # excessive resources.
  MAX_LINES: int = 120
-
- # Images are resized such that the longest side does not exceed this
- # number of pixels before passing them to the OCR model. Increasing
- # this value may improve accuracy at the cost of speed and memory.
  RESIZE_MAX_SIDE: int = 800
-
- # Threshold used when highlighting tokens by relative probability. A
- # ratio of Top2/Top1 greater than this value will cause the token to
- # be highlighted in red.
  REL_PROB_TH: float = 0.70
-
- # A regex used to clean up Unicode control characters before text
- # normalization. Soft hyphens, zero width spaces and similar marks
- # interfere with accurate token matching.
  CLEANUP: re.Pattern = re.compile(r"[\u00AD\u200B\u200C\u200D]")
-
- # Default font path for rendering predictions directly on the image.
  FONT_PATH: str = os.path.join("NotoSansDevanagari-Regular.ttf")
 
 
- # ----------------------------------------------------------------------
- # Model loading
- #
- # Loading the model and associated tokenizer/processor is slow. Use
- # functools.lru_cache to ensure this only happens once per process.
-
  @lru_cache(maxsize=1)
  def load_model():
- """Load the OCR model, tokenizer and feature extractor.
-
- Returns
- -------
- model : VisionEncoderDecoderModel
- The loaded model in evaluation mode.
- tokenizer : PreTrainedTokenizerFast
- Tokenizer corresponding to the decoder part of the model.
- feature_extractor : callable
- Feature extractor converting PIL images into model inputs.
- device : torch.device
- The device (CPU or CUDA) used for inference.
- """
  model_path = "AnjaliSarawgi/model-oct"
- # In an offline environment the HF token is None; if you wish
- # to use a private model you can set HF_TOKEN in your environment.
  hf_token = os.environ.get("HF_TOKEN")
  model = VisionEncoderDecoderModel.from_pretrained(model_path, token=hf_token)
  tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path, token=hf_token)
@@ -117,45 +43,13 @@ def load_model():
  return model, tokenizer, processor.feature_extractor, device
 
 
- # ----------------------------------------------------------------------
- # Utility functions
- #
-
  def clean_text(text: str) -> str:
- """Normalize and collapse whitespace from a decoded string.
-
- Parameters
- ----------
- text : str
- The raw decoded string from the model.
-
- Returns
- -------
- str
- The cleaned string with Unicode normalization and whitespace
- removed. All whitespace characters are stripped since the
- predictions are later tokenized at the akshara (syllable) level.
- """
  text = unicodedata.normalize("NFC", text)
  text = CLEANUP.sub("", text)
  return re.sub(r"\s+", "", text)
 
 
  def prepare_image(image: Image.Image, max_side: int = RESIZE_MAX_SIDE) -> Image.Image:
- """Resize the image so that its longest side equals max_side.
-
- Parameters
- ----------
- image : PIL.Image
- Input image.
- max_side : int, optional
- Maximum allowed size for the longest side of the image.
-
- Returns
- -------
- PIL.Image
- The resized image.
- """
  img = image.convert("RGB")
  w, h = img.size
  if max(w, h) > max_side:
@@ -164,36 +58,10 @@ def prepare_image(image: Image.Image, max_side: int = RESIZE_MAX_SIDE) -> Image.
 
 
  def get_amp_ctx():
- """Return the appropriate context manager for automatic mixed precision."""
  return torch.cuda.amp.autocast if torch.cuda.is_available() else contextlib.nullcontext
 
-
- # ----------------------------------------------------------------------
- # XML parsing and segmentation
  #
  def parse_boxes_from_xml(xml_bytes: bytes, level: str = "line", image_size: tuple | None = None):
- """Parse ALTO or PAGE XML to extract bounding boxes.
-
- Parameters
- ----------
- xml_bytes : bytes
- Raw XML bytes.
- level : {"block", "line", "word"}, optional
- The segmentation level to extract. For OCR we use "line".
- image_size : tuple or None
- If provided, image_size=(width, height) allows rescaling
- coordinates to match the actual image. ALTO files often store
- absolute page sizes that differ from the image dimensions.
-
- Returns
- -------
- list of dict
- Each dict represents a bounding box with keys:
- - "bbox": [x1, y1, x2, y2]
- - "points": list of (x, y) if polygonal coordinates exist
- - "id": line identifier (string)
- - "label": the type of element (e.g. TextLine)
- """
  def _strip_ns(elem):
  for e in elem.iter():
  if isinstance(e.tag, str) and e.tag.startswith("{"):
@@ -323,7 +191,6 @@ def parse_boxes_from_xml(xml_bytes: bytes, level: str = "line", image_size: tupl
 
 
  def sort_boxes_reading_order(boxes, y_tol: int = 10):
- """Sort bounding boxes top‑to‑bottom then left‑to‑right."""
  def key(b):
  x1, y1, x2, y2 = b["bbox"]
  return (round(y1 / max(1, y_tol)), y1, x1)
@@ -331,21 +198,6 @@ def sort_boxes_reading_order(boxes, y_tol: int = 10):
 
 
  def draw_boxes(img: Image.Image, boxes):
- """Overlay semi‑transparent red polygons or rectangles on an image.
-
- Parameters
- ----------
- img : PIL.Image
- The base image.
- boxes : list of dict
- Segmentation boxes with either 'points' or 'bbox' keys.
-
- Returns
- -------
- PIL.Image
- An image with red overlays marking each box. Boxes are numbered
- starting from 1.
- """
  base = img.convert("RGBA")
  overlay = Image.new("RGBA", base.size, (0, 0, 0, 0))
  draw = ImageDraw.Draw(overlay)
@@ -366,33 +218,7 @@
  return Image.alpha_composite(base, overlay).convert("RGB")
 
 
- # ----------------------------------------------------------------------
- # OCR inference per line
- #
  def predict_and_score_once(image: Image.Image, line_id: int = 1, topk: int = TOPK):
- """Run the model on a single cropped line and return predictions and scores.
-
- This helper wraps the model.generate call to obtain per‑token
- probabilities and derives a DataFrame summarizing each decoding step.
-
- Parameters
- ----------
- image : PIL.Image
- Cropped segment to process.
- line_id : int, optional
- Identifier used in the output DataFrame.
- topk : int, optional
- Number of alternative tokens to keep for each decoding position.
-
- Returns
- -------
- decoded_text : str
- Cleaned predicted string for the line.
- df : pandas.DataFrame
- Table with one row per generated token containing the following
- columns: line_id, seq_pos, token_id, token, confidence,
- rel_prob, entropy, gap12, alt_tokens, alt_probs.
- """
  model, tokenizer, feature_extractor, device = load_model()
  img = prepare_image(image)
  pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values.to(device)
@@ -479,13 +305,6 @@
  return decoded_text, df
 
 
- # ----------------------------------------------------------------------
- # Text splitting into aksharas (syllable units) for highlighting
- #
- # The following regex and helper functions split a Devanagari string into
- # aksharas. This is necessary to map model tokens back to spans of
- # characters when highlighting uncertain predictions.
-
  DEV_CONS = "\u0915-\u0939\u0958-\u095F\u0978-\u097F" # consonants incl. nukta variants range
  INDEP_VOW = "\u0904-\u0914" # independent vowels
  NUKTA = "\u093C" # nukta
@@ -533,29 +352,6 @@ def parse_alt_tokens(s: str):
  def highlight_tokens_with_tooltips(
  line_text: str, df_tok: pd.DataFrame, red_threshold: float, metric_column: str
  ) -> str:
- """Insert HTML spans around tokens whose chosen metric exceeds threshold.
-
- The metric column can be "rel_prob" (relative probability) or
- "entropy". Tokens with a value strictly greater than red_threshold
- will be wrapped in a span with a tooltip listing alternative
- predictions and their probabilities.
-
- Parameters
- ----------
- line_text : str
- The cleaned line prediction.
- df_tok : pandas.DataFrame
- DataFrame of token statistics for the corresponding line.
- red_threshold : float
- Values above this threshold will be highlighted.
- metric_column : str
- Column name in df_tok used for thresholding.
-
- Returns
- -------
- str
- An HTML string with <span> elements inserted.
- """
  aks, spans = split_aksharas(line_text)
  joined = "".join(aks)
  used_ranges = []
@@ -639,40 +435,7 @@ def run_ocr(
  apply_bin: bool,
  highlight_metric: str,
  ):
- """Run the OCR pipeline on user inputs and return results for Gradio.
-
- Parameters
- ----------
- image : numpy.ndarray or None
- The uploaded image converted to a NumPy array by Gradio. If
- None, the function returns empty results.
- xml_file : tuple or None
- A tuple representing the uploaded XML file as provided by
- gr.File. The first element is the file name and the second is
- bytes. If None, no segmentation is applied and the entire
- image is processed as a single line.
- apply_gray : bool
- Whether to convert the image to grayscale before OCR.
- apply_bin : bool
- Whether to apply binarization (Otsu threshold) before OCR. If
- selected, grayscale conversion is applied first automatically.
- highlight_metric : str
- Which metric to use for highlighting ("Relative Probability" or
- "Entropy").
-
- Returns
- -------
- overlay_img : PIL.Image or None
- Image with segmentation boxes drawn. None if no input image.
- predictions_html : str
- HTML formatted predicted text with highlighted tokens.
- df_scores : pandas.DataFrame or None
- DataFrame of per‑token statistics. None if no input image.
- txt_file_path : str or None
- Path to a temporary .txt file containing the plain predicted text.
- csv_file_path : str or None
- Path to a temporary CSV file containing the extended token scores.
- """
+
  if image is None:
  return None, "", None, None, None
  # Convert the numpy array to a PIL image
@@ -806,7 +569,10 @@ def run_ocr(
  csv_fd.close()
  except Exception:
  csv_path = None
- return overlay_img, predicted_html, df_all, txt_path, csv_path
+ # return overlay_img, predicted_html, df_all, txt_path, csv_path
+ txt_bytes = plain_text.encode("utf-8")
+ csv_bytes = df_all.to_csv(index=False).encode("utf-8")
+ return overlay_img, predicted_html, df_all, txt_bytes, csv_bytes
 
 
  # ----------------------------------------------------------------------
@@ -815,7 +581,7 @@ def run_ocr(
  def create_gradio_interface():
  """Create and return the Gradio Blocks interface."""
  with gr.Blocks(title="Old Nepali HTR") as demo:
- gr.Markdown("""# Old Nepali HTR (Gradio)\n\nUpload a scanned image and (optionally) a segmentation XML file. Choose preprocessing\nsteps and a highlight metric, then click **Run OCR** to extract the text.\nUncertain tokens are highlighted with tooltips showing alternative predictions.\nYou can edit the plain text below and download it or the full token scores.""")
+ gr.Markdown("""# Old Nepali HTR \n\nUpload a scanned image and (optionally) a segmentation XML file. Choose preprocessing\nsteps and a highlight metric, then click **Run OCR** to extract the text.\nUncertain tokens are highlighted with tooltips showing alternative predictions.\nYou can edit the plain text below and download it or the full token scores.""")
  with gr.Row():
  image_input = gr.Image(type="numpy", label="Upload Image")
  # When used as an input, gr.File returns either a file path or bytes
@@ -841,8 +607,11 @@ def create_gradio_interface():
  predictions_output = gr.HTML(label="Predictions (HTML)")
  df_output = gr.DataFrame(label="Token Scores", interactive=False)
  # Separate file outputs for the OCR prediction, token scores and edited text.
- ocr_txt_output = gr.File(label="Download OCR Prediction (.txt)")
- ocr_csv_output = gr.File(label="Download Token Scores (.csv)")
+ # ocr_txt_output = gr.File(label="Download OCR Prediction (.txt)")
+ # ocr_csv_output = gr.File(label="Download Token Scores (.csv)")
+ ocr_txt_output = gr.File(label="Download OCR Prediction (.txt)", type="binary")
+ ocr_csv_output = gr.File(label="Download Token Scores (.csv)", type="binary")
+
  edited_txt_output = gr.File(label="Download edited text (.txt)")
 
  # Editable text area
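
The functional change in this commit: run_ocr no longer returns temp-file paths for the downloads but UTF-8 encoded bytes, and the two gr.File outputs are declared with type="binary". A minimal, self-contained sketch of the same wiring, assuming (as the diff does) that a gr.File output declared with type="binary" accepts raw bytes returned from a callback; the callback name and payload below are illustrative, not the app's exact code:

    import gradio as gr

    def export_prediction() -> bytes:
        # Build the download payload in memory instead of writing a
        # temporary file, mirroring the bytes-returning run_ocr above.
        return "predicted text".encode("utf-8")

    with gr.Blocks() as demo:
        run_button = gr.Button("Export")
        txt_output = gr.File(label="Download OCR Prediction (.txt)", type="binary")
        run_button.click(export_prediction, inputs=None, outputs=txt_output)

    demo.launch()

Keeping the payload in memory avoids leaving temporary .txt/.csv files behind on the Space between requests.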
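For context on the token-scores table these downloads carry: the removed predict_and_score_once docstring describes wrapping model.generate to obtain per-token statistics (confidence, the Top2/Top1 rel_prob ratio compared against REL_PROB_TH, entropy, and top-k alternatives). A sketch of that scoring loop using the standard transformers generation outputs; the column names mirror the app's, but this is an illustration rather than the app's exact code:

    import torch

    def score_line(model, tokenizer, pixel_values, max_len=128, topk=3):
        out = model.generate(
            pixel_values,
            max_length=max_len,
            output_scores=True,            # keep per-step logits
            return_dict_in_generate=True,  # so out.scores is available
        )
        rows = []
        for pos, step_logits in enumerate(out.scores):  # one tensor per decoded step
            probs = torch.softmax(step_logits[0], dim=-1)  # batch size 1
            top_p, top_i = probs.topk(topk)
            rows.append({
                "seq_pos": pos,
                "token": tokenizer.decode([top_i[0].item()]),
                "confidence": top_p[0].item(),
                "rel_prob": (top_p[1] / top_p[0]).item(),  # Top2/Top1 ratio
                "entropy": -(probs * probs.clamp_min(1e-12).log()).sum().item(),
                "alt_tokens": [tokenizer.decode([i.item()]) for i in top_i],
            })
        return rows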
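The akshara constants kept by the diff (DEV_CONS, INDEP_VOW, NUKTA) exist because, as the removed comment block explains, per-token scores must be mapped back to character spans before highlighting. A simplified splitter in that spirit; the MATRA and SIGNS ranges here are assumptions, and real Devanagari segmentation has more cases, so treat this as an illustration of the technique, not the app's split_aksharas:

    import re

    DEV_CONS = "\u0915-\u0939\u0958-\u095F\u0978-\u097F"  # consonants incl. nukta variants
    INDEP_VOW = "\u0904-\u0914"                           # independent vowels
    MATRA = "\u093E-\u094C\u0962\u0963"                   # dependent vowel signs (assumed range)
    VIRAMA = "\u094D"
    SIGNS = "\u0900-\u0903\u093C"                         # candrabindu/anusvara/visarga/nukta

    # One akshara: a consonant cluster joined by viramas with optional signs
    # and a final matra, or an independent vowel, or any other character.
    AKSHARA = re.compile(
        f"[{DEV_CONS}][{SIGNS}]?(?:{VIRAMA}[{DEV_CONS}][{SIGNS}]?)*[{MATRA}]?[{SIGNS}]?"
        f"|[{INDEP_VOW}][{SIGNS}]?"
        "|."
    )

    def split_aksharas(text: str) -> list[str]:
        return AKSHARA.findall(text)

    print(split_aksharas("नमस्ते"))  # -> ['न', 'म', 'स्ते']

Splitting at this level is what lets highlight_tokens_with_tooltips wrap whole syllables, rather than stray combining marks, in its <span> elements.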