Gabriel committed on
Commit 7263d32
1 Parent(s): d50ab3a

refactored the pipeline

app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 
-from helper.gradio_config import css, js, theme
+from helper.gradio_config import css, theme
 from helper.text.text_about import TextAbout
 from helper.text.text_app import TextApp
 from helper.text.text_howto import TextHowTo
@@ -21,7 +21,7 @@ with gr.Blocks(title="HTR Riksarkivet", theme=theme, css=css) as demo:
     with gr.Tab("How to use"):
         with gr.Tabs():
             with gr.Tab("HTR Tool"):
-                with gr.Row().style(equal_height=False):
+                with gr.Row(equal_height=False):
                     with gr.Column():
                         gr.Markdown(TextHowTo.htr_tool)
                     with gr.Column():
@@ -33,7 +33,7 @@ with gr.Blocks(title="HTR Riksarkivet", theme=theme, css=css) as demo:
                     gr.Markdown(TextHowTo.reach_out)
 
             with gr.Tab("Stepwise HTR Tool"):
-                with gr.Row().style(equal_height=False):
+                with gr.Row(equal_height=False):
                     gr.Markdown(TextHowTo.stepwise_htr_tool)
                 with gr.Row():
                     gr.Markdown(TextHowTo.stepwise_htr_tool_tab_intro)
@@ -115,7 +115,7 @@ print(job.result())
         with gr.Column():
             gr.Markdown(TextRoadmap.discussion)
 
-demo.load(None, None, None, _js=js)
+# demo.load(None, None, None, _js=js)
 
 
 demo.queue(concurrency_count=1, max_size=1)
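
Note: most edits in this commit track Gradio's deprecation of the `.style()` method; layout keyword arguments now go straight into the component constructor. A minimal sketch of the migration pattern (the component contents here are illustrative, not from this repo):

import gradio as gr

with gr.Blocks() as demo:
    # before: with gr.Row().style(equal_height=False):
    with gr.Row(equal_height=False):  # after: the kwarg moves into the constructor
        gr.Markdown("left column")
        gr.Markdown("right column")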
helper/gradio_config.py CHANGED
@@ -21,6 +21,9 @@ class GradioConfig:
     #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 450px}
     #gallery {height: 400px}
     .fixed-height.svelte-g4rw9.svelte-g4rw9 {min-height: 400px;}
+
+    #gallery_lines > div.preview.svelte-1b19cri > div.thumbnails.scroll-hide.svelte-1b19cri {display: none;}
+
     """
 
     def generate_tooltip_css(self):
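
Note: the added rule hides the thumbnail strip of the gallery whose `elem_id` is `gallery_lines`; the Svelte class hashes pin it to the current Gradio build. The elem_id-to-CSS coupling, sketched with a simplified selector:

import gradio as gr

css = "#gallery_lines .thumbnails {display: none;}"  # simplified; the real rule targets exact Svelte classes

with gr.Blocks(css=css) as demo:
    gr.Gallery(elem_id="gallery_lines")  # the id the selector hooks into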
requirements.txt CHANGED
@@ -14,6 +14,8 @@ pillow==9.5.0
 
 
 
+
+
 # make install_openmmlab (they are excuted in dockerfile)
 # !pip install -U openmim
 # !mim install mmengine
src/htr_pipeline/gradio_backend.py CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
 
 from src.htr_pipeline.inferencer import Inferencer, InferencerInterface
 from src.htr_pipeline.pipeline import Pipeline, PipelineInterface
+from src.htr_pipeline.utils.helper import gradio_info
 
 
 class SingletonModelLoader:
@@ -28,6 +29,7 @@ class FastTrack:
         self.pipeline: PipelineInterface = model_loader.pipeline
 
     def segment_to_xml(self, image, radio_button_choices):
+        gr.Info("Running HTR-pipeline")
         xml_xml = "page_xml.xml"
         xml_txt = "page_txt.txt"
 
@@ -40,6 +42,11 @@ class FastTrack:
             f.write(rendered_xml)
 
         xml_img = self.visualize_xml_and_return_txt(image, xml_txt)
+        returned_file_extension = self.file_extenstion_to_return(radio_button_choices, xml_xml, xml_txt)
+
+        return xml_img, returned_file_extension, gr.update(visible=True)
+
+    def file_extenstion_to_return(self, radio_button_choices, xml_xml, xml_txt):
         if len(radio_button_choices) < 2:
             if radio_button_choices[0] == "Txt":
                 returned_file_extension = xml_txt
@@ -47,8 +54,7 @@ class FastTrack:
                 returned_file_extension = xml_xml
             else:
                 returned_file_extension = [xml_txt, xml_xml]
-
-        return xml_img, returned_file_extension, gr.update(visible=True)
+        return returned_file_extension
 
     def segment_to_xml_api(self, image):
         rendered_xml = self.pipeline.running_htr_pipeline(image)
@@ -70,12 +76,14 @@ class CustomTrack:
     def __init__(self, model_loader):
         self.inferencer: InferencerInterface = model_loader.inferencer
 
+    @gradio_info("Running Segment Region")
     def region_segment(self, image, pred_score_threshold, containments_treshold):
         predicted_regions, regions_cropped_ordered, _, _ = self.inferencer.predict_regions(
             image, pred_score_threshold, containments_treshold
         )
         return predicted_regions, regions_cropped_ordered, gr.update(visible=False), gr.update(visible=True)
 
+    @gradio_info("Running Segment Line")
     def line_segment(self, image, pred_score_threshold, containments_threshold):
         predicted_lines, lines_cropped_ordered, _ = self.inferencer.predict_lines(
             image, pred_score_threshold, containments_threshold
@@ -93,22 +101,35 @@ class CustomTrack:
         )
 
     def transcribe_text(self, df, images):
+        gr.Info("Running Transcribe Lines")
         transcription_temp_list_with_score = []
         mapping_dict = {}
 
+        total_images = len(images)
+        current_index = 0
+
+        bool_to_show_placeholder = gr.update(visible=True)
+        bool_to_show_control_results_transcribe = gr.update(visible=False)
+
         for image in images:
+            current_index += 1
+
+            if current_index == total_images:
+                bool_to_show_control_results_transcribe = gr.update(visible=True)
+                bool_to_show_placeholder = gr.update(visible=False)
+
             transcribed_text, prediction_score_from_htr = self.inferencer.transcribe(image)
             transcription_temp_list_with_score.append((transcribed_text, prediction_score_from_htr))
 
             df_trans_explore = pd.DataFrame(
-                transcription_temp_list_with_score, columns=["Transcribed text", "HTR prediction score"]
+                transcription_temp_list_with_score, columns=["Transcribed text", "Pred score"]
             )
 
             mapping_dict[transcribed_text] = image
 
-            yield df_trans_explore[["Transcribed text"]], df_trans_explore, mapping_dict, gr.update(
-                visible=False
-            ), gr.update(visible=True), gr.update(visible=False)
+            yield df_trans_explore[
+                ["Transcribed text"]
+            ], df_trans_explore, mapping_dict, bool_to_show_control_results_transcribe, bool_to_show_placeholder
 
     def get_select_index_image(self, images_from_gallery, evt: gr.SelectData):
         return images_from_gallery[evt.index]["name"]
@@ -120,7 +141,7 @@ class CustomTrack:
         new_first = [sorted_image]
         new_list = [img for txt, img in mapping_dict.items() if txt != key_text]
         new_first.extend(new_list)
-        return new_first
+        return new_first, key_text
 
     def download_df_to_txt(self, transcribed_df):
         text_in_list = transcribed_df["Transcribed text"].tolist()
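
Note: `transcribe_text` is a generator, so each `yield` streams an intermediate dataframe to the UI while the loop runs; the visibility updates now live in variables that flip only on the last image, replacing the fixed `gr.update(...)` tuple of the old `yield`. A minimal sketch of the same pattern (component names are illustrative):

import gradio as gr

def stream_rows(items):
    rows = []
    for i, item in enumerate(items, start=1):
        rows.append(f"processed {item}")
        # reveal the results panel only on the final yield,
        # mirroring bool_to_show_control_results_transcribe above
        yield "\n".join(rows), gr.update(visible=(i == len(items)))

with gr.Blocks() as demo:
    items = gr.State(["a", "b", "c"])
    run_btn = gr.Button("Run")
    progress_box = gr.Textbox(label="Progress")
    results_panel = gr.Textbox(label="Results", visible=False)
    run_btn.click(stream_rows, inputs=items, outputs=[progress_box, results_panel])

demo.queue()  # generator handlers require the queue in Gradio 3.x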
src/htr_pipeline/pipeline.py CHANGED
@@ -6,15 +6,18 @@ import numpy as np
 from src.htr_pipeline.inferencer import Inferencer
 from src.htr_pipeline.utils.helper import timer_func
 from src.htr_pipeline.utils.parser_xml import XmlParser
+from src.htr_pipeline.utils.pipeline_inferencer import PipelineInferencer, XMLHelper
 from src.htr_pipeline.utils.preprocess_img import Preprocess
-from src.htr_pipeline.utils.process_xml import XMLHelper
+from src.htr_pipeline.utils.process_segmask import SegMaskHelper
+from src.htr_pipeline.utils.visualize_xml import XmlViz
+from src.htr_pipeline.utils.xml_helper import XMLHelper
 
 
 class Pipeline:
     def __init__(self, inferencer: Inferencer) -> None:
         self.inferencer = inferencer
-        self.xml = XMLHelper()
         self.preprocess_img = Preprocess()
+        self.pipeline_inferencer = PipelineInferencer(SegMaskHelper(), XMLHelper())
 
     @timer_func
     def running_htr_pipeline(
@@ -27,7 +30,7 @@ class Pipeline:
         input_image = self.preprocess_img.binarize_img(input_image)
         image = mmcv.imread(input_image)
 
-        rendered_xml = self.xml.image_to_page_xml(
+        rendered_xml = self.pipeline_inferencer.image_to_page_xml(
             image, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold, self.inferencer
         )
 
@@ -35,14 +38,15 @@ class Pipeline:
 
     @timer_func
     def visualize_xml(self, input_image: np.ndarray) -> np.ndarray:
-        self.xml_visualizer_and_parser = XmlParser()
+        xml_viz = XmlViz()
         bin_input_image = self.preprocess_img.binarize_img(input_image)
-        xml_image = self.xml_visualizer_and_parser.visualize_xml(bin_input_image)
+        xml_image = xml_viz.visualize_xml(bin_input_image)
         return xml_image
 
     @timer_func
     def parse_xml_to_txt(self) -> None:
-        self.xml_visualizer_and_parser.xml_to_txt()
+        xml_visualizer_and_parser = XmlParser()
+        xml_visualizer_and_parser.xml_to_txt()
 
 
 class PipelineInterface(Protocol):
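
Note: `Pipeline` no longer assembles the XML itself; it delegates to `PipelineInferencer`, which receives its collaborators through the constructor, so either can be swapped for a stub in tests. That injection style, reduced to a sketch (names simplified, not from this repo):

from typing import Protocol

class Renderer(Protocol):
    def render(self, template_data: dict) -> str: ...

class FakeRenderer:
    def render(self, template_data: dict) -> str:
        return "<PcGts/>"  # canned output for a test

class Worker:
    def __init__(self, renderer: Renderer):
        self.renderer = renderer  # injected, not constructed internally

worker = Worker(FakeRenderer())
print(worker.renderer.render({}))  # -> <PcGts/>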
src/htr_pipeline/utils/helper.py CHANGED
@@ -1,7 +1,9 @@
 import functools
 import threading
 import time
+from functools import wraps
 
+import gradio as gr
 import tqdm
 
 
@@ -75,6 +77,19 @@ def another_long_running_function(*args, **kwargs):
     return "success"
 
 
+# Decorator for logging
+def gradio_info(message):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            gr.Info(message)
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
 if __name__ == "__main__":
     # Basic example
     retval = provide_progress_bar(long_running_function, estimated_time=5)
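
Note: `gradio_info` pops a `gr.Info` toast before the wrapped function runs, which is what puts the "Running Segment Region" / "Running Segment Line" messages on screen in gradio_backend.py. Usage reduces to:

from src.htr_pipeline.utils.helper import gradio_info

@gradio_info("Running Segment Region")
def region_segment(image):
    # the toast fires first; the return value and exceptions pass through unchanged
    return image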
src/htr_pipeline/utils/parser_xml.py CHANGED
@@ -1,10 +1,5 @@
-import math
-import os
-import random
 import xml.etree.ElementTree as ET
 
-from PIL import Image, ImageDraw, ImageFont
-
 
 class XmlParser:
     def __init__(self, page_xml="./page_xml.xml"):
@@ -12,61 +7,6 @@ class XmlParser:
         self.root = self.tree.getroot()
         self.namespace = "{http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15}"
 
-    def visualize_xml(
-        self,
-        background_image,
-        font_size=9,
-        text_offset=10,
-        font_path_tff="./src/htr_pipeline/utils/templates/arial.ttf",
-    ):
-        image = Image.fromarray(background_image).convert("RGBA")
-        image_width = int(self.root.find(f"{self.namespace}Page").attrib["imageWidth"])
-        image_height = int(self.root.find(f"{self.namespace}Page").attrib["imageHeight"])
-
-        text_offset = -text_offset
-        base_font_size = font_size
-        font_path = font_path_tff
-
-        max_bbox_width = 0  # Initialize maximum bounding box width
-
-        for textregion in self.root.findall(f".//{self.namespace}TextRegion"):
-            coords = textregion.find(f"{self.namespace}Coords").attrib["points"].split()
-            points = [tuple(map(int, point.split(","))) for point in coords]
-            x_coords, y_coords = zip(*points)
-            min_x, max_x = min(x_coords), max(x_coords)
-            bbox_width = max_x - min_x  # Width of the current bounding box
-            max_bbox_width = max(max_bbox_width, bbox_width)  # Update maximum bounding box width
-
-        scaling_factor = max_bbox_width / 400.0  # Use maximum bounding box width for scaling
-        font_size_scaled = int(base_font_size * scaling_factor)
-        font = ImageFont.truetype(font_path, font_size_scaled)
-
-        for textregion in self.root.findall(f".//{self.namespace}TextRegion"):
-            fill_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 100)
-            for textline in textregion.findall(f".//{self.namespace}TextLine"):
-                coords = textline.find(f"{self.namespace}Coords").attrib["points"].split()
-                points = [tuple(map(int, point.split(","))) for point in coords]
-
-                poly_image = Image.new("RGBA", image.size)
-                poly_draw = ImageDraw.Draw(poly_image)
-                poly_draw.polygon(points, fill=fill_color)
-
-                text = textline.find(f"{self.namespace}TextEquiv").find(f"{self.namespace}Unicode").text
-
-                x_coords, y_coords = zip(*points)
-                min_x, max_x = min(x_coords), max(x_coords)
-                min_y = min(y_coords)
-                text_width, text_height = poly_draw.textsize(text, font=font)  # Get text size
-                text_position = (
-                    (min_x + max_x) // 2 - text_width // 2,
-                    min_y + text_offset,
-                )  # Center text horizontally
-
-                poly_draw.text(text_position, text, fill=(0, 0, 0), font=font)
-                image = Image.alpha_composite(image, poly_image)
-
-        return image
-
     def xml_to_txt(self, output_file="page_txt.txt"):
         with open(output_file, "w", encoding="utf-8") as f:
             for textregion in self.root.findall(f".//{self.namespace}TextRegion"):
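
Note: `XmlParser` now only extracts text; the PAGE 2013 namespace has to be carried in every `findall`. The ElementTree pattern in isolation (the inline XML is a stand-in for page_xml.xml):

import xml.etree.ElementTree as ET

ns = "{http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15}"
root = ET.fromstring(
    '<PcGts xmlns="http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15">'
    "<Page><TextRegion><TextLine><TextEquiv><Unicode>hello</Unicode></TextEquiv>"
    "</TextLine></TextRegion></Page></PcGts>"
)
for textline in root.findall(f".//{ns}TextLine"):
    print(textline.find(f"{ns}TextEquiv").find(f"{ns}Unicode").text)  # -> hello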
src/htr_pipeline/utils/pipeline_inferencer.py ADDED
@@ -0,0 +1,107 @@
+from tqdm import tqdm
+
+from src.htr_pipeline.utils.process_segmask import SegMaskHelper
+from src.htr_pipeline.utils.xml_helper import XMLHelper
+
+
+class PipelineInferencer:
+    def __init__(self, process_seg_mask: SegMaskHelper, xml_helper: XMLHelper):
+        self.process_seg_mask = process_seg_mask
+        self.xml_helper = xml_helper
+
+    def image_to_page_xml(
+        self, image, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold, inferencer
+    ):
+        template_data = self.xml_helper.prepare_template_data(self.xml_helper.xml_file_name, image)
+        template_data["textRegions"] = self._process_regions(
+            image, inferencer, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold
+        )
+
+        print(template_data)
+        return self.xml_helper.render(template_data)
+
+    def _process_regions(
+        self,
+        image,
+        inferencer,
+        pred_score_threshold_regions,
+        pred_score_threshold_lines,
+        containments_threshold,
+        htr_threshold=0.7,
+    ):
+        _, regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered = inferencer.predict_regions(
+            image,
+            pred_score_threshold=pred_score_threshold_regions,
+            containments_threshold=containments_threshold,
+            visualize=False,
+        )
+
+        region_data_list = []
+        for i, data in tqdm(enumerate(zip(regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered))):
+            region_data = self._create_region_data(
+                data, i, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
+            )
+            if region_data:
+                region_data_list.append(region_data)
+
+        return region_data_list
+
+    def _create_region_data(
+        self, data, index, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
+    ):
+        text_region, reg_pol, mask = data
+        region_data = {"id": f"region_{index}", "boundary": reg_pol}
+
+        text_lines, htr_scores = self._process_lines(
+            text_region,
+            inferencer,
+            pred_score_threshold_lines,
+            containments_threshold,
+            mask,
+            region_data["id"],
+            htr_threshold,
+        )
+
+        if not text_lines:
+            return None
+
+        region_data["textLines"] = text_lines
+        mean_htr_score = sum(htr_scores) / len(htr_scores) if htr_scores else 0
+
+        return region_data if mean_htr_score > htr_threshold else None
+
+    def _process_lines(
+        self, text_region, inferencer, pred_score_threshold, containments_threshold, mask, region_id, htr_threshold=0.7
+    ):
+        _, lines_cropped_ordered, line_polygons_ordered = inferencer.predict_lines(
+            text_region, pred_score_threshold, containments_threshold, visualize=False, custom_track=False
+        )
+
+        if not lines_cropped_ordered:
+            return None, []
+
+        line_polygons_ordered_trans = self.process_seg_mask._translate_line_coords(mask, line_polygons_ordered)
+
+        text_lines = []
+        htr_scores = []
+        for index, (line, line_pol) in enumerate(zip(lines_cropped_ordered, line_polygons_ordered_trans)):
+            line_data, htr_score = self._create_line_data(line, line_pol, index, region_id, inferencer, htr_threshold)
+
+            if line_data:
+                text_lines.append(line_data)
+                htr_scores.append(htr_score)
+
+        return text_lines, htr_scores
+
+    def _create_line_data(self, line, line_pol, index, region_id, inferencer, htr_threshold):
+        line_data = {"id": f"line_{region_id}_{index}", "boundary": line_pol}
+
+        transcribed_text, htr_score = inferencer.transcribe(line)
+        line_data["unicode"] = self.xml_helper.escape_xml_chars(transcribed_text)
+        line_data["pred_score"] = round(htr_score, 4)
+
+        return line_data if htr_score > htr_threshold else None, htr_score
+
+
+if __name__ == "__main__":
+    pass
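
Note: both filters hinge on `htr_threshold=0.7`. A line is kept only if its HTR score clears the threshold, and, unlike the deleted process_xml.py (which averaged over all lines), only the surviving scores enter the region mean. In isolation:

htr_threshold = 0.7
line_scores = [0.91, 0.85, 0.42]

kept = [s for s in line_scores if s > htr_threshold]   # [0.91, 0.85]; 0.42 is dropped
mean_htr_score = sum(kept) / len(kept) if kept else 0  # 0.88
keep_region = mean_htr_score > htr_threshold           # True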
src/htr_pipeline/utils/process_xml.py DELETED
@@ -1,167 +0,0 @@
-import os
-import re
-from datetime import datetime
-
-import jinja2
-from tqdm import tqdm
-
-from src.htr_pipeline.inferencer import InferencerInterface
-from src.htr_pipeline.utils.process_segmask import SegMaskHelper
-
-
-class XMLHelper:
-    def __init__(self):
-        self.process_seg_mask = SegMaskHelper()
-
-    def image_to_page_xml(
-        self,
-        image,
-        pred_score_threshold_regions,
-        pred_score_threshold_lines,
-        containments_threshold,
-        inferencer: InferencerInterface,
-        xml_file_name="page_xml.xml",
-    ):
-        img_height = image.shape[0]
-        img_width = image.shape[1]
-        img_file_name = xml_file_name
-
-        template_data = self.prepare_template_data(img_file_name, img_width, img_height)
-
-        template_data["textRegions"] = self._process_regions(
-            image,
-            inferencer,
-            pred_score_threshold_regions,
-            pred_score_threshold_lines,
-            containments_threshold,
-        )
-
-        rendered_xml = self._render_xml(template_data)
-
-        return rendered_xml
-
-    def _transform_coords(self, input_string):
-        pattern = r"\[\s*([^\s,]+)\s*,\s*([^\s\]]+)\s*\]"
-        replacement = r"\1,\2"
-        return re.sub(pattern, replacement, input_string)
-
-    def _render_xml(self, template_data):
-        template_loader = jinja2.FileSystemLoader(searchpath="./src/htr_pipeline/utils/templates")
-        template_env = jinja2.Environment(loader=template_loader, trim_blocks=True)
-        template = template_env.get_template("page_xml_2013.xml")
-        rendered_xml = template.render(template_data)
-        rendered_xml = self._transform_coords(rendered_xml)
-        return rendered_xml
-
-    def prepare_template_data(self, img_file_name, img_width, img_height):
-        now = datetime.now()
-        date_time = now.strftime("%Y-%m-%d, %H:%M:%S")
-        return {
-            "created": date_time,
-            "imageFilename": img_file_name,
-            "imageWidth": img_width,
-            "imageHeight": img_height,
-            "textRegions": list(),
-        }
-
-    def _process_regions(
-        self,
-        image,
-        inferencer: InferencerInterface,
-        pred_score_threshold_regions,
-        pred_score_threshold_lines,
-        containments_threshold,
-        htr_threshold=0.7,
-    ):
-        _, regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered = inferencer.predict_regions(
-            image,
-            pred_score_threshold=pred_score_threshold_regions,
-            containments_threshold=containments_threshold,
-            visualize=False,
-        )
-
-        region_data_list = []
-        for i, (text_region, reg_pol, mask) in tqdm(
-            enumerate(zip(regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered))
-        ):
-            region_id = "region_" + str(i)
-            region_data = dict()
-            region_data["id"] = region_id
-            region_data["boundary"] = reg_pol
-
-            text_lines, htr_scores = self._process_lines(
-                text_region,
-                inferencer,
-                pred_score_threshold_lines,
-                containments_threshold,
-                mask,
-                region_id,
-            )
-
-            if text_lines is None:
-                continue
-
-            region_data["textLines"] = text_lines
-            mean_htr_score = sum(htr_scores) / len(htr_scores)
-
-            if mean_htr_score > htr_threshold:
-                region_data_list.append(region_data)
-
-        return region_data_list
-
-    def _process_lines(
-        self,
-        text_region,
-        inferencer: InferencerInterface,
-        pred_score_threshold_lines,
-        containments_threshold,
-        mask,
-        region_id,
-        htr_threshold=0.7,
-    ):
-        _, lines_cropped_ordered, line_polygons_ordered = inferencer.predict_lines(
-            text_region,
-            pred_score_threshold=pred_score_threshold_lines,
-            containments_threshold=containments_threshold,
-            visualize=False,
-            custom_track=False,
-        )
-
-        if lines_cropped_ordered is None:
-            return None, None
-
-        line_polygons_ordered_trans = self.process_seg_mask._translate_line_coords(mask, line_polygons_ordered)
-
-        htr_scores = list()
-        text_lines = list()
-
-        for j, (line, line_pol) in enumerate(zip(lines_cropped_ordered, line_polygons_ordered_trans)):
-            line_id = "line_" + region_id + "_" + str(j)
-            line_data = dict()
-            line_data["id"] = line_id
-            line_data["boundary"] = line_pol
-
-            transcribed_text, htr_score = inferencer.transcribe(line)
-            escaped_text = self._escape_xml_chars(transcribed_text)
-            line_data["unicode"] = escaped_text
-            line_data["pred_score"] = round(htr_score, 4)
-
-            htr_scores.append(htr_score)
-
-            if htr_score > htr_threshold:
-                text_lines.append(line_data)
-
-        return text_lines, htr_scores
-
-    def _escape_xml_chars(self, textline):
-        return (
-            textline.replace("&", "&amp;")
-            .replace("<", "&lt;")
-            .replace(">", "&gt;")
-            .replace("'", "&apos;")
-            .replace('"', "&quot;")
-        )
-
-
-if __name__ == "__main__":
-    pass
src/htr_pipeline/utils/visualize_xml.py ADDED
@@ -0,0 +1,68 @@
+import random
+import xml.etree.ElementTree as ET
+
+from PIL import Image, ImageDraw, ImageFont
+
+
+class XmlViz:
+    def __init__(self, page_xml="./page_xml.xml"):
+        self.tree = ET.parse(page_xml, parser=ET.XMLParser(encoding="utf-8"))
+        self.root = self.tree.getroot()
+        self.namespace = "{http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15}"
+
+    def visualize_xml(
+        self,
+        background_image,
+        font_size=9,
+        text_offset=10,
+        font_path_tff="./src/htr_pipeline/utils/templates/arial.ttf",
+    ):
+        image = Image.fromarray(background_image).convert("RGBA")
+
+        text_offset = -text_offset
+        base_font_size = font_size
+        font_path = font_path_tff
+
+        max_bbox_width = 0  # Initialize maximum bounding box width
+
+        for textregion in self.root.findall(f".//{self.namespace}TextRegion"):
+            coords = textregion.find(f"{self.namespace}Coords").attrib["points"].split()
+            points = [tuple(map(int, point.split(","))) for point in coords]
+            x_coords, y_coords = zip(*points)
+            min_x, max_x = min(x_coords), max(x_coords)
+            bbox_width = max_x - min_x  # Width of the current bounding box
+            max_bbox_width = max(max_bbox_width, bbox_width)  # Update maximum bounding box width
+
+        scaling_factor = max_bbox_width / 400.0  # Use maximum bounding box width for scaling
+        font_size_scaled = int(base_font_size * scaling_factor)
+        font = ImageFont.truetype(font_path, font_size_scaled)
+
+        for textregion in self.root.findall(f".//{self.namespace}TextRegion"):
+            fill_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 100)
+            for textline in textregion.findall(f".//{self.namespace}TextLine"):
+                coords = textline.find(f"{self.namespace}Coords").attrib["points"].split()
+                points = [tuple(map(int, point.split(","))) for point in coords]
+
+                poly_image = Image.new("RGBA", image.size)
+                poly_draw = ImageDraw.Draw(poly_image)
+                poly_draw.polygon(points, fill=fill_color)
+
+                text = textline.find(f"{self.namespace}TextEquiv").find(f"{self.namespace}Unicode").text
+
+                x_coords, y_coords = zip(*points)
+                min_x, max_x = min(x_coords), max(x_coords)
+                min_y = min(y_coords)
+                text_width, text_height = poly_draw.textsize(text, font=font)  # Get text size
+                text_position = (
+                    (min_x + max_x) // 2 - text_width // 2,
+                    min_y + text_offset,
+                )  # Center text horizontally
+
+                poly_draw.text(text_position, text, fill=(0, 0, 0), font=font)
+                image = Image.alpha_composite(image, poly_image)
+
+        return image
+
+
+if __name__ == "__main__":
+    pass
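
Note: `visualize_xml` measures text with `ImageDraw.textsize`, which exists on the pinned pillow==9.5.0 but was removed in Pillow 10. A forward-compatible measurement (a sketch, not part of this commit) would use `textbbox` instead:

from PIL import Image, ImageDraw, ImageFont

draw = ImageDraw.Draw(Image.new("RGBA", (200, 50)))
font = ImageFont.load_default()

left, top, right, bottom = draw.textbbox((0, 0), "hello", font=font)
text_width, text_height = right - left, bottom - top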
src/htr_pipeline/utils/xml_helper.py ADDED
@@ -0,0 +1,55 @@
+import re
+from datetime import datetime
+
+import jinja2
+
+
+class XMLHelper:
+    def __init__(self, xml_file_name="page_xml.xml"):
+        self.xml_file_name = xml_file_name
+        self.searchpath = "./src/htr_pipeline/utils/templates"
+        self.template = "page_xml_2013.xml"
+
+    def render(self, template_data):
+        rendered_xml = self._render_xml(template_data)
+        return rendered_xml
+
+    def _transform_coords(self, input_string):
+        pattern = r"\[\s*([^\s,]+)\s*,\s*([^\s\]]+)\s*\]"
+        replacement = r"\1,\2"
+        return re.sub(pattern, replacement, input_string)
+
+    def _render_xml(self, template_data):
+        template_loader = jinja2.FileSystemLoader(searchpath=self.searchpath)
+        template_env = jinja2.Environment(loader=template_loader, trim_blocks=True)
+        template = template_env.get_template(self.template)
+        rendered_xml = template.render(template_data)
+        rendered_xml = self._transform_coords(rendered_xml)
+        return rendered_xml
+
+    def prepare_template_data(self, img_file_name, image):
+        img_height = image.shape[0]
+        img_width = image.shape[1]
+
+        now = datetime.now()
+        date_time = now.strftime("%Y-%m-%d, %H:%M:%S")
+        return {
+            "created": date_time,
+            "imageFilename": img_file_name,
+            "imageWidth": img_width,
+            "imageHeight": img_height,
+            "textRegions": list(),
+        }
+
+    def escape_xml_chars(self, textline):
+        return (
+            textline.replace("&", "&amp;")
+            .replace("<", "&lt;")
+            .replace(">", "&gt;")
+            .replace("'", "&apos;")
+            .replace('"', "&quot;")
+        )
+
+
+if __name__ == "__main__":
+    pass
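
Note: `_transform_coords` post-processes the rendered Jinja template, collapsing Python-style `[x, y]` pairs into the `x,y` points format PAGE XML expects. In isolation:

import re

pattern = r"\[\s*([^\s,]+)\s*,\s*([^\s\]]+)\s*\]"
rendered = 'points="[12, 34] [56, 78]"'
print(re.sub(pattern, r"\1,\2", rendered))  # -> points="12,34 56,78"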
tabs/htr_tool.py CHANGED
@@ -19,32 +19,17 @@ with gr.Blocks() as htr_tool_tab:
         )
 
         with gr.Row():
-            # with gr.Group():
-            #     callback = gr.CSVLogger()
-            #     # hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "htr_pipelin_flags")
-            #     flagging_button = gr.Button(
-            #         "Flag",
-            #         variant="secondary",
-            #         visible=True,
-            #     ).style(full_width=True)
-            #     radio_file_input = gr.Radio(
-            #         value="Text file", choices=["Text file ", "Page XML file "], label="What kind file output?"
-            #     )
-
             radio_file_input = gr.CheckboxGroup(
                 choices=["Txt", "XML"],
-                value=["Txt"],
+                value=["XML"],
                 label="Output file extension",
                 # info="Only txt and page xml is supported for now!",
+                scale=1,
             )
 
             htr_pipeline_button = gr.Button(
-                "Run HTR",
-                variant="primary",
-                visible=True,
-                elem_id="run_pipeline_button",
-            ).style(full_width=False)
-
+                "Run HTR", variant="primary", visible=True, elem_id="run_pipeline_button", scale=1
+            )
         with gr.Group():
             with gr.Row():
                 fast_file_downlod = gr.File(label="Download output file", visible=False)
@@ -75,11 +60,7 @@ with gr.Blocks() as htr_tool_tab:
     fast_track_output_image = gr.Image(label="HTR results visualizer", type="numpy", tool="editor", height=650)
 
     with gr.Row(visible=False) as api_placeholder:
-        htr_pipeline_button_api = gr.Button(
-            "Run pipeline",
-            variant="primary",
-            visible=False,
-        ).style(full_width=False)
+        htr_pipeline_button_api = gr.Button("Run pipeline", variant="primary", visible=False, scale=1)
 
     xml_rendered_placeholder_for_api = gr.Textbox(visible=False)
     htr_pipeline_button.click(
@@ -94,8 +75,3 @@ with gr.Blocks() as htr_tool_tab:
         outputs=[xml_rendered_placeholder_for_api],
         api_name="predict",
     )
-
-    # callback.setup([fast_track_input_region_image], "flagged_data_points")
-    # flagging_button.click(lambda *args: callback.flag(args), [fast_track_input_region_image], None, preprocess=False)
-    # flagging_button.click(lambda: (gr.update(value="Flagged")), outputs=flagging_button)
-    # fast_track_input_region_image.change(lambda: (gr.update(value="Flag")), outputs=flagging_button)
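Note: `htr_pipeline_button_api` stays hidden but registers `api_name="predict"`, so the fast-track pipeline can be driven programmatically, which is what the `print(job.result())` snippet in app.py's how-to text refers to. A hedged sketch with gradio_client (the URL is a placeholder):

from gradio_client import Client

client = Client("https://<space-url>")  # placeholder for the deployed Space
job = client.submit("path/to/image.jpg", api_name="/predict")
print(job.result())  # the rendered PAGE XML string
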
tabs/stepwise_htr_tool.py CHANGED
@@ -25,7 +25,8 @@ with gr.Blocks() as stepwise_htr_tool_tab:
                         label="Image to Region segment",
                         # type="numpy",
                         tool="editor",
-                    ).style(height=350)
+                        height=350,
+                    )
 
                     with gr.Accordion("Region segment settings:", open=False):
                         with gr.Row():
@@ -63,7 +64,7 @@ with gr.Blocks() as stepwise_htr_tool_tab:
                         "Segment Region",
                         variant="primary",
                         elem_id="region_segment_button",
-                    )  # .style(full_width=False)
+                    )
 
                 with gr.Row():
                     with gr.Accordion("Example images to use:", open=False) as example_accord:
@@ -75,7 +76,7 @@ with gr.Blocks() as stepwise_htr_tool_tab:
                         )
 
                 with gr.Column(scale=3):
-                    output_region_image = gr.Image(label="Segmented regions", type="numpy").style(height=600)
+                    output_region_image = gr.Image(label="Segmented regions", type="numpy", height=600)
 
         ##############################################
         with gr.Tab("2. Line Segmentation"):
@@ -84,27 +85,27 @@ with gr.Blocks() as stepwise_htr_tool_tab:
                 # type="numpy",
                 interactive="False",
                 visible=True,
-            ).style(height=600)
+                height=600,
+            )
 
             with gr.Row(visible=False) as control_line_segment:
                 with gr.Column(scale=2):
                     with gr.Box():
                         regions_cropped_gallery = gr.Gallery(
                             label="Segmented regions",
-                            show_label=False,
                             elem_id="gallery",
-                        ).style(
                             columns=[2],
                             rows=[2],
                             # object_fit="contain",
-                            height=400,
+                            height=450,
                             preview=True,
                             container=False,
                         )
 
                         input_region_from_gallery = gr.Image(
-                            label="Region segmentation to line segment", interactive="False", visible=False
-                        ).style(height=400)
+                            label="Region segmentation to line segment", interactive="False", visible=False, height=400
+                        )
+
                     with gr.Row():
                         with gr.Accordion("Line segment settings:", open=False):
                             with gr.Row():
@@ -126,7 +127,7 @@ with gr.Blocks() as stepwise_htr_tool_tab:
                                 info="""The minimum required overlap or similarity
                                 for a detected region or object to be considered valid""",
                             )
-                            with gr.Row().style(equal_height=False):
+                            with gr.Row(equal_height=False):
                                 line_segment_model_dropdown = gr.Dropdown(
                                     choices=["Riksarkivet/RmtDet_lines"],
                                     value="Riksarkivet/RmtDet_lines",
@@ -138,22 +139,22 @@ with gr.Blocks() as stepwise_htr_tool_tab:
                                     " ",
                                     variant="Secondary",
                                     # elem_id="center_button",
-                                ).style(full_width=True)
+                                    scale=1,
+                                )
 
                                 line_segment_button = gr.Button(
                                     "Segment Lines",
                                     variant="primary",
                                     # elem_id="center_button",
-                                ).style(full_width=True)
+                                    scale=1,
+                                )
 
                 with gr.Column(scale=3):
                     # gr.Markdown("""lorem ipsum""")
 
                     output_line_from_region = gr.Image(
-                        label="Segmented lines",
-                        type="numpy",
-                        interactive="False",
-                    ).style(height=600)
+                        label="Segmented lines", type="numpy", interactive="False", height=600
+                    )
 
         ###############################################
         with gr.Tab("3. Transcribe Text"):
@@ -162,19 +163,16 @@ with gr.Blocks() as stepwise_htr_tool_tab:
                 # type="numpy",
                 interactive="False",
                 visible=True,
-            ).style(height=600)
+                height=600,
+            )
 
             with gr.Row(visible=False) as control_htr:
                 inputs_lines_to_transcribe = gr.Variable()
 
                 with gr.Column(scale=2):
                     image_inputs_lines_to_transcribe = gr.Image(
-                        label="Transcribed lines",
-                        type="numpy",
-                        interactive="False",
-                        visible=False,
-                    ).style(height=470)
-
+                        label="Transcribed lines", type="numpy", interactive="False", visible=False, height=470
+                    )
                     with gr.Row():
                         with gr.Accordion("Transcribe settings:", open=False):
                             transcriber_model = gr.Dropdown(
@@ -184,30 +182,21 @@ with gr.Blocks() as stepwise_htr_tool_tab:
                                 info="Will add more models later!",
                             )
                             with gr.Row():
-                                clear_transcribe_button = gr.Button(" ", variant="Secondary", visible=True).style(
-                                    full_width=True
-                                )
-                                transcribe_button = gr.Button("Transcribe lines", variant="primary", visible=True).style(
-                                    full_width=True
-                                )
+                                clear_transcribe_button = gr.Button(" ", variant="Secondary", visible=True, scale=1)
 
-                                donwload_txt_button = gr.Button("Download text", variant="secondary", visible=False).style(
-                                    full_width=True
-                                )
-
-                    with gr.Row():
-                        txt_file_downlod = gr.File(label="Download text", visible=False)
+                                transcribe_button = gr.Button("Transcribe Lines", variant="primary", visible=True, scale=1)
 
                 with gr.Column(scale=3):
                     with gr.Row():
                         transcribed_text_df = gr.Dataframe(
                             headers=["Transcribed text"],
-                            max_rows=15,
+                            max_rows=14,
                             col_count=(1, "fixed"),
                             wrap=True,
                             interactive=False,
                             overflow_row_behaviour="paginate",
-                        ).style(height=600)
+                            height=600,
+                        )
 
         #####################################
         with gr.Tab("4. Explore Results"):
@@ -216,35 +205,43 @@ with gr.Blocks() as stepwise_htr_tool_tab:
                 # type="numpy",
                 interactive="False",
                 visible=True,
-            ).style(height=600)
+                height=600,
+            )
 
-            with gr.Row(visible=False) as control_results_transcribe:
+            with gr.Row(visible=False, equal_height=False) as control_results_transcribe:
                 with gr.Column(scale=1, visible=True):
                     with gr.Box():
                         temp_gallery_input = gr.Variable()
 
                         gallery_inputs_lines_to_transcribe = gr.Gallery(
                             label="Cropped transcribed lines",
-                            show_label=True,
                             elem_id="gallery_lines",
-                        ).style(
                             columns=[3],
                             rows=[3],
                             # object_fit="contain",
-                            # height="600",
+                            height=300,
                             preview=True,
                             container=False,
                         )
+
+                    dataframe_text_index = gr.Textbox(
+                        label="Text from DataFrame selection",
+                        info="Click on a dataframe cell to view the corresponding transcribed text line crop. You can also sort the dataframe to easily locate specific entries.",
+                        lines=2,
+                        interactive=False,
+                    )
+
                 with gr.Column(scale=1, visible=True):
                     mapping_dict = gr.Variable()
                     transcribed_text_df_finish = gr.Dataframe(
-                        headers=["Transcribed text", "Pred score"],
-                        max_rows=15,
+                        headers=["Transcribed text", "pred score"],
+                        max_rows=14,
                         col_count=(2, "fixed"),
                         wrap=True,
                         interactive=False,
                         overflow_row_behaviour="paginate",
-                    ).style(height=600)
+                        height=600,
+                    )
 
         # custom track
         region_segment_button.click(
@@ -260,7 +257,7 @@ with gr.Blocks() as stepwise_htr_tool_tab:
     transcribed_text_df_finish.select(
         fn=custom_track.get_select_index_df,
         inputs=[transcribed_text_df_finish, mapping_dict],
-        outputs=gallery_inputs_lines_to_transcribe,
+        outputs=[gallery_inputs_lines_to_transcribe, dataframe_text_index],
    )
 
     line_segment_button.click(
@@ -287,23 +284,12 @@ with gr.Blocks() as stepwise_htr_tool_tab:
             transcribed_text_df,
             transcribed_text_df_finish,
             mapping_dict,
-            txt_file_downlod,
+            # Hide
             control_results_transcribe,
             image_placeholder_explore_results,
         ],
     )
 
-    donwload_txt_button.click(
-        custom_track.download_df_to_txt,
-        inputs=transcribed_text_df,
-        outputs=[txt_file_downlod, txt_file_downlod],
-    )
-
-    # def remove_temp_vis():
-    #     if os.path.exists("./vis_data"):
-    #         os.remove("././vis_data")
-    #     return None
-
     clear_button.click(
         lambda: (
             (shutil.rmtree("./vis_data") if os.path.exists("./vis_data") else None, None)[1],
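
Note: `get_select_index_df` now returns both the reordered gallery and the selected `key_text`, which feeds the new `dataframe_text_index` textbox through the two-element `outputs` list. A minimal sketch of a `gr.SelectData` handler fanned out to two outputs (component names are illustrative):

import gradio as gr

def on_select(df, evt: gr.SelectData):
    row, _ = evt.index              # for a Dataframe, evt.index is (row, column)
    cell_text = df.iloc[row, 0]
    return cell_text, f"row {row}"  # one value per output component

with gr.Blocks() as demo:
    df = gr.Dataframe(value=[["hello"], ["world"]], headers=["Transcribed text"])
    picked = gr.Textbox(label="Selected text")
    position = gr.Textbox(label="Position")
    df.select(on_select, inputs=df, outputs=[picked, position])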