Gabriel commited on
Commit
c60ebd1
1 Parent(s): d1b10a8

0.0.3 release with Trocr and compare support

Browse files
.github/README.md CHANGED
@@ -15,6 +15,13 @@ HTRFLOW is more than just a demo; it's a testament to the advancement of open so
15
 
16
  ## Run app
17
 
 
 
 
 
 
 
 
18
  Install libraries with Makefile:
19
 
20
  ```
 
15
 
16
  ## Run app
17
 
18
+ Use virtual env.
19
+
20
+ ```
21
+ python3 -m venv .venv
22
+ source .venv/bin/activate
23
+ ```
24
+
25
  Install libraries with Makefile:
26
 
27
  ```
.gitignore CHANGED
@@ -28,4 +28,11 @@ TODO.md
28
  .cache_images/
29
  traffic_data.db
30
  ip_data.csv
31
- data/
 
 
 
 
 
 
 
 
28
  .cache_images/
29
  traffic_data.db
30
  ip_data.csv
31
+ data/
32
+
33
+ #mlflow
34
+ mlruns/
35
+ test.ipynb
36
+
37
+ #models
38
+ models--Riksarkivet--HTR_pipeline_models/
app.py CHANGED
@@ -31,6 +31,21 @@ with gr.Blocks(title="Riksarkivet", theme=theme, css=css) as demo:
31
  with gr.Tab("Overview"):
32
  overview.render()
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  SECRET_KEY = os.environ.get("HUB_TOKEN", False)
35
  if SECRET_KEY:
36
  demo.load(
 
31
  with gr.Tab("Overview"):
32
  overview.render()
33
 
34
+ with gr.Tab("How to use"):
35
+ with gr.Row():
36
+ with gr.Column():
37
+ gr.Markdown("## Fast track")
38
+ gr.Video(
39
+ value="https://github.com/Borg93/htr_gradio_file_placeholder/blob/main/eating_spaghetti.mp4",
40
+ format="mp4",
41
+ )
42
+ with gr.Column():
43
+ gr.Markdown("## Stepwise")
44
+ gr.Video(
45
+ "https://github.com/Borg93/htr_gradio_file_placeholder/blob/main/htr_tool_media_cut.mp4",
46
+ format="mp4",
47
+ )
48
+
49
  SECRET_KEY = os.environ.get("HUB_TOKEN", False)
50
  if SECRET_KEY:
51
  demo.load(
helper/text/overview/changelog_roadmap/changelog.md CHANGED
@@ -2,38 +2,13 @@
2
 
3
  All notable changes to HTRFLOW will be documented here.
4
 
5
- ### [0.0.2] - 2023-11-01
6
 
7
  #### Added
8
 
9
- - Better documentation for API, see **Overview** > **Duplicating for own use & API**
10
- - Better documentation for restrictions of app, see **Overview** > **HTRFLOW**
11
 
12
  #### Fixed
13
 
14
- - Fixed bug for API, [issue](https://github.com/Riksarkivet/HTRFLOW/issues/2)
15
-
16
- #### Changed
17
-
18
- - Changed named for **FAQ & Discussion** to **FAQ & Contact**
19
-
20
- ---
21
-
22
- ### [0.0.1] - 2023-10-23
23
-
24
- #### Added
25
-
26
- - Added a new feature to **Stepwise** > **Explore results** > New Text diff and CER component
27
-
28
- #### Fixed
29
-
30
- - Fixed naming conventions of tabs in app so they are more coherent with the code.
31
-
32
- #### Changed
33
-
34
- - Changed the layout in both Fast track and Stepwise to improve the UX
35
-
36
- - Examples are viewed in the middle of the layout
37
- - "Advanced settings" are initial hidden
38
-
39
- - Removed **help** tab for now (documentation of Fast track and Stepwise will come in a later release)
 
2
 
3
  All notable changes to HTRFLOW will be documented here.
4
 
5
+ ### [0.0.3] - 2023-11-06
6
 
7
  #### Added
8
 
9
+ - Support for TROCR -> Latin and Eng model
10
+ - New feature! Compare different runs with GT, see tab **Fast track** > **Compare**
11
 
12
  #### Fixed
13
 
14
+ - Fixed bug for Docker and running app locally, [issue](https://github.com/Riksarkivet/HTRFLOW/issues/2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
helper/text/overview/changelog_roadmap/old_changelog.md ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Changelog
2
+
3
+ All notable changes to HTRFLOW will be documented here.
4
+
5
+ ### [0.0.2] - 2023-11-01
6
+
7
+ #### Added
8
+
9
+ - Better documentation for API, see **Overview** > **Duplicating for own use & API**
10
+ - Better documentation for restrictions of app, see **Overview** > **HTRFLOW**
11
+
12
+ #### Fixed
13
+
14
+ - Fixed bug for API, [issue](https://github.com/Riksarkivet/HTRFLOW/issues/2)
15
+
16
+ #### Changed
17
+
18
+ - Changed named for **FAQ & Discussion** to **FAQ & Contact**
19
+
20
+ ---
21
+
22
+ ### [0.0.1] - 2023-10-23
23
+
24
+ #### Added
25
+
26
+ - Added a new feature to **Stepwise** > **Explore results** > New Text diff and CER component
27
+
28
+ #### Fixed
29
+
30
+ - Fixed naming conventions of tabs in app so they are more coherent with the code.
31
+
32
+ #### Changed
33
+
34
+ - Changed the layout in both Fast track and Stepwise to improve the UX
35
+
36
+ - Examples are viewed in the middle of the layout
37
+ - "Advanced settings" are initial hidden
38
+
39
+ - Removed **help** tab for now (documentation of Fast track and Stepwise will come in a later release)
helper/text/text_app.py CHANGED
@@ -1,5 +1,5 @@
1
  class TextApp:
2
- demo_version = """<em>Version 0.0.2</em>"""
3
 
4
  title_markdown = """
5
 
 
1
  class TextApp:
2
+ demo_version = """<em>Version 0.0.3</em>"""
3
 
4
  title_markdown = """
5
 
helper/text/text_overview.py CHANGED
@@ -21,6 +21,8 @@ class TextOverview:
21
 
22
  # Changelog & Roadmap
23
  changelog = read_markdown("helper/text/overview/changelog_roadmap/changelog.md")
 
 
24
  roadmap = read_markdown("helper/text/overview/changelog_roadmap/roadmap.md")
25
 
26
  # duplicate & api
 
21
 
22
  # Changelog & Roadmap
23
  changelog = read_markdown("helper/text/overview/changelog_roadmap/changelog.md")
24
+ old_changelog = read_markdown("helper/text/overview/changelog_roadmap/old_changelog.md")
25
+
26
  roadmap = read_markdown("helper/text/overview/changelog_roadmap/roadmap.md")
27
 
28
  # duplicate & api
src/htr_pipeline/gradio_backend.py CHANGED
@@ -1,6 +1,9 @@
1
  import os
 
 
2
 
3
  import cv2
 
4
  import gradio as gr
5
  import numpy as np
6
  import pandas as pd
@@ -36,7 +39,7 @@ class FastTrack:
36
  def __init__(self, model_loader):
37
  self.pipeline: PipelineInterface = model_loader.pipeline
38
 
39
- def segment_to_xml(self, image, radio_button_choices):
40
  handling_callback_stop_inferencer()
41
 
42
  gr.Info("Excuting HTR on image")
@@ -46,7 +49,9 @@ class FastTrack:
46
  if os.path.exists(f"./{xml_xml}"):
47
  os.remove(f"./{xml_xml}")
48
 
49
- rendered_xml = self.pipeline.running_htr_pipeline(image)
 
 
50
 
51
  with open(xml_xml, "w") as f:
52
  f.write(rendered_xml)
@@ -172,13 +177,83 @@ class CustomTrack:
172
 
173
  return file_name, gr.update(visible=True)
174
 
175
- # def transcribe_text_another_model(self, df, images):
176
- # transcription_temp_list = []
177
- # for image in images:
178
- # transcribed_text = inferencer.transcribe_different_model(image)
179
- # transcription_temp_list.append(transcribed_text)
180
- # df_trans = pd.DataFrame(transcription_temp_list, columns=["Transcribed_text"])
181
- # yield df_trans, df_trans, gr.update(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
 
184
  if __name__ == "__main__":
 
1
  import os
2
+ import xml.etree.ElementTree as ET
3
+ from difflib import Differ
4
 
5
  import cv2
6
+ import evaluate
7
  import gradio as gr
8
  import numpy as np
9
  import pandas as pd
 
39
  def __init__(self, model_loader):
40
  self.pipeline: PipelineInterface = model_loader.pipeline
41
 
42
+ def segment_to_xml(self, image, radio_button_choices, htr_tool_transcriber_model_dropdown):
43
  handling_callback_stop_inferencer()
44
 
45
  gr.Info("Excuting HTR on image")
 
49
  if os.path.exists(f"./{xml_xml}"):
50
  os.remove(f"./{xml_xml}")
51
 
52
+ htr_tool_transcriber_model_dropdown
53
+
54
+ rendered_xml = self.pipeline.running_htr_pipeline(image, htr_tool_transcriber_model_dropdown)
55
 
56
  with open(xml_xml, "w") as f:
57
  f.write(rendered_xml)
 
177
 
178
  return file_name, gr.update(visible=True)
179
 
180
+
181
+ # Temporary structured here...
182
+
183
+
184
+ def upload_file(files):
185
+ return files.name, gr.update(visible=True)
186
+
187
+
188
+ def diff_texts(text1, text2):
189
+ d = Differ()
190
+ return [(token[2:], token[0] if token[0] != " " else None) for token in d.compare(text1, text2)]
191
+
192
+
193
+ def compute_cer_a_and_b_with_gt(run_a, run_b, run_gt):
194
+ text_run_a, text_run_b, text_run_gt = reading_xml_files_string(run_a, run_b, run_gt)
195
+
196
+ cer_metric = evaluate.load("cer")
197
+
198
+ if text_run_a == text_run_gt:
199
+ return "No Ground Truth was provided."
200
+
201
+ elif text_run_a == text_run_b:
202
+ return f"A & B -> GT: {round(cer_metric.compute(predictions=[text_run_a], references=[text_run_gt]), 4)}"
203
+
204
+ else:
205
+ return f"A -> GT: {round(cer_metric.compute(predictions=[text_run_a], references=[text_run_gt]), 4)}, B -> GT {round(cer_metric.compute(predictions=[text_run_b], references=[text_run_gt]), 4)}"
206
+
207
+
208
+ def temporary_xml_parser(page_xml):
209
+ tree = ET.parse(page_xml, parser=ET.XMLParser(encoding="utf-8"))
210
+ root = tree.getroot()
211
+ namespace = "{http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15}"
212
+ text_list = []
213
+ for textregion in root.findall(f".//{namespace}TextRegion"):
214
+ for textline in textregion.findall(f".//{namespace}TextLine"):
215
+ text = textline.find(f"{namespace}TextEquiv").find(f"{namespace}Unicode").text
216
+ text_list.append(text)
217
+ return " ".join(text_list)
218
+
219
+
220
+ def compare_diff_runs_highlight(run_a, run_b, run_gt):
221
+ text_run_a, text_run_b, text_run_gt = reading_xml_files_string(run_a, run_b, run_gt)
222
+
223
+ diff_runs = diff_texts(text_run_a, text_run_b)
224
+ diff_gt = diff_texts(text_run_a, text_run_gt)
225
+
226
+ return diff_runs, diff_gt
227
+
228
+
229
+ def reading_xml_files_string(run_a, run_b, run_gt):
230
+ if run_a is None:
231
+ return
232
+
233
+ if run_gt is None:
234
+ gr.Warning("No GT was provided, setting GT to A")
235
+ run_gt = run_a
236
+
237
+ if run_b is None:
238
+ gr.Warning("No B was provided, setting B to A")
239
+ run_b = run_a
240
+
241
+ text_run_a = temporary_xml_parser(run_a.name)
242
+ text_run_b = temporary_xml_parser(run_b.name)
243
+ text_run_gt = temporary_xml_parser(run_gt.name)
244
+ return text_run_a, text_run_b, text_run_gt
245
+
246
+
247
+ def update_selected_tab_output_and_setting():
248
+ return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
249
+
250
+
251
+ def update_selected_tab_image_viewer():
252
+ return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
253
+
254
+
255
+ def update_selected_tab_model_compare():
256
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
257
 
258
 
259
  if __name__ == "__main__":
src/htr_pipeline/inferencer.py CHANGED
@@ -3,6 +3,8 @@ from typing import Protocol, Tuple
3
  import gradio as gr
4
  import mmcv
5
  import numpy as np
 
 
6
 
7
  from src.htr_pipeline.models import HtrModels
8
  from src.htr_pipeline.utils.filter_segmask import FilterSegMask
@@ -116,20 +118,28 @@ class Inferencer:
116
  result_rec = self.htr_model_inferencer(line_cropped)
117
  return result_rec["predictions"][0]["text"], round(result_rec["predictions"][0]["scores"], 4)
118
 
119
- # def transcribe_different_model(self, image):
120
- # processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
121
- # model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
 
 
 
 
 
 
 
 
 
 
 
122
 
123
- # # prepare image
124
- # pixel_values = processor(image, return_tensors="pt").pixel_values
125
 
126
- # # generate (no beam search)
127
- # generated_ids = model.generate(pixel_values)
128
 
129
- # # decode
130
- # generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
131
 
132
- # return generated_text
133
 
134
 
135
  class InferencerInterface(Protocol):
 
3
  import gradio as gr
4
  import mmcv
5
  import numpy as np
6
+ import torch
7
+ from transformers import AutoImageProcessor, TrOCRProcessor, VisionEncoderDecoderModel
8
 
9
  from src.htr_pipeline.models import HtrModels
10
  from src.htr_pipeline.utils.filter_segmask import FilterSegMask
 
118
  result_rec = self.htr_model_inferencer(line_cropped)
119
  return result_rec["predictions"][0]["text"], round(result_rec["predictions"][0]["scores"], 4)
120
 
121
+ @timer_func
122
+ def transcribe_different_model(self, image, htr_tool_transcriber_model_dropdown):
123
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
124
+
125
+ if htr_tool_transcriber_model_dropdown == "pstroe/bullinger-general-model":
126
+ processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
127
+ image_processor = AutoImageProcessor.from_pretrained("pstroe/bullinger-general-model")
128
+ model = VisionEncoderDecoderModel.from_pretrained("pstroe/bullinger-general-model")
129
+ pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)
130
+
131
+ else:
132
+ processor = TrOCRProcessor.from_pretrained(htr_tool_transcriber_model_dropdown)
133
+ model = VisionEncoderDecoderModel.from_pretrained(htr_tool_transcriber_model_dropdown)
134
+ pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
135
 
136
+ model.to(device)
 
137
 
138
+ generated_ids = model.generate(pixel_values)
 
139
 
140
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
141
 
142
+ return generated_text, 1.0
143
 
144
 
145
  class InferencerInterface(Protocol):
src/htr_pipeline/models.py CHANGED
@@ -11,26 +11,28 @@ from mmocr.apis import TextRecInferencer
11
  class HtrModels:
12
  def __init__(self, local_run=False):
13
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
- SECRET_KEY = os.environ.get("AM_I_IN_A_DOCKER_CONTAINER", False)
15
 
16
  model_folder = "./models"
17
  self.region_config = f"{model_folder}/RmtDet_regions/rtmdet_m_textregions_2_concat.py"
 
18
 
19
  self.line_config = f"{model_folder}/RmtDet_lines/rtmdet_m_textlines_2_concat.py"
20
  self.line_checkpoint = f"{model_folder}/RmtDet_lines/epoch_12.pth"
 
21
  self.mmocr_config = f"{model_folder}/SATRN/_base_satrn_shallow_concat.py"
 
22
 
23
- if SECRET_KEY:
 
 
 
 
 
24
  config_path = self.get_config()
25
  self.region_checkpoint = config_path["region_checkpoint"]
26
  self.line_checkpoint = config_path["line_checkpoint"]
27
  self.mmocr_checkpoint = config_path["mmocr_checkpoint"]
28
 
29
- else:
30
- self.region_checkpoint = f"{model_folder}/RmtDet_regions/epoch_12.pth"
31
- self.line_checkpoint = f"{model_folder}/RmtDet_lines/epoch_12.pth"
32
- self.mmocr_checkpoint = f"{model_folder}/SATRN/epoch_5.pth"
33
-
34
  def load_region_model(self):
35
  # build the model from a config file and a checkpoint file
36
  return DetInferencer(self.region_config, self.region_checkpoint, device=self.device)
 
11
  class HtrModels:
12
  def __init__(self, local_run=False):
13
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
14
 
15
  model_folder = "./models"
16
  self.region_config = f"{model_folder}/RmtDet_regions/rtmdet_m_textregions_2_concat.py"
17
+ self.region_checkpoint = f"{model_folder}/RmtDet_regions/epoch_12.pth"
18
 
19
  self.line_config = f"{model_folder}/RmtDet_lines/rtmdet_m_textlines_2_concat.py"
20
  self.line_checkpoint = f"{model_folder}/RmtDet_lines/epoch_12.pth"
21
+
22
  self.mmocr_config = f"{model_folder}/SATRN/_base_satrn_shallow_concat.py"
23
+ self.mmocr_checkpoint = f"{model_folder}/SATRN/epoch_5.pth"
24
 
25
+ # Check if model files exist at the specified paths, if not, get the config
26
+ if not (
27
+ os.path.exists(self.region_checkpoint)
28
+ and os.path.exists(self.line_checkpoint)
29
+ and os.path.exists(self.mmocr_checkpoint)
30
+ ):
31
  config_path = self.get_config()
32
  self.region_checkpoint = config_path["region_checkpoint"]
33
  self.line_checkpoint = config_path["line_checkpoint"]
34
  self.mmocr_checkpoint = config_path["mmocr_checkpoint"]
35
 
 
 
 
 
 
36
  def load_region_model(self):
37
  # build the model from a config file and a checkpoint file
38
  return DetInferencer(self.region_config, self.region_checkpoint, device=self.device)
src/htr_pipeline/pipeline.py CHANGED
@@ -23,6 +23,7 @@ class Pipeline:
23
  def running_htr_pipeline(
24
  self,
25
  input_image: np.ndarray,
 
26
  pred_score_threshold_regions: float = 0.4,
27
  pred_score_threshold_lines: float = 0.4,
28
  containments_threshold: float = 0.5,
@@ -31,7 +32,12 @@ class Pipeline:
31
  image = mmcv.imread(input_image)
32
 
33
  rendered_xml = self.pipeline_inferencer.image_to_page_xml(
34
- image, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold, self.inferencer
 
 
 
 
 
35
  )
36
 
37
  return rendered_xml
 
23
  def running_htr_pipeline(
24
  self,
25
  input_image: np.ndarray,
26
+ htr_tool_transcriber_model_dropdown,
27
  pred_score_threshold_regions: float = 0.4,
28
  pred_score_threshold_lines: float = 0.4,
29
  containments_threshold: float = 0.5,
 
32
  image = mmcv.imread(input_image)
33
 
34
  rendered_xml = self.pipeline_inferencer.image_to_page_xml(
35
+ image,
36
+ htr_tool_transcriber_model_dropdown,
37
+ pred_score_threshold_regions,
38
+ pred_score_threshold_lines,
39
+ containments_threshold,
40
+ self.inferencer,
41
  )
42
 
43
  return rendered_xml
src/htr_pipeline/utils/pipeline_inferencer.py CHANGED
@@ -15,8 +15,17 @@ class PipelineInferencer:
15
  self.xml_helper = xml_helper
16
 
17
  def image_to_page_xml(
18
- self, image, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold, inferencer
 
 
 
 
 
 
19
  ):
 
 
 
20
  template_data = self.xml_helper.prepare_template_data(self.xml_helper.xml_file_name, image)
21
  template_data["textRegions"] = self._process_regions(
22
  image, inferencer, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold
@@ -121,7 +130,14 @@ class PipelineInferencer:
121
  def _create_line_data(self, line, line_pol, index, region_id, inferencer, htr_threshold):
122
  line_data = {"id": f"line_{region_id}_{index}", "boundary": line_pol}
123
 
124
- transcribed_text, htr_score = inferencer.transcribe(line)
 
 
 
 
 
 
 
125
  line_data["unicode"] = self.xml_helper.escape_xml_chars(transcribed_text)
126
  line_data["pred_score"] = round(htr_score, 4)
127
 
 
15
  self.xml_helper = xml_helper
16
 
17
  def image_to_page_xml(
18
+ self,
19
+ image,
20
+ htr_tool_transcriber_model_dropdown,
21
+ pred_score_threshold_regions,
22
+ pred_score_threshold_lines,
23
+ containments_threshold,
24
+ inferencer,
25
  ):
26
+ # temporary solutions.. for trocr..
27
+ self.htr_tool_transcriber_model_dropdown = htr_tool_transcriber_model_dropdown
28
+
29
  template_data = self.xml_helper.prepare_template_data(self.xml_helper.xml_file_name, image)
30
  template_data["textRegions"] = self._process_regions(
31
  image, inferencer, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold
 
130
  def _create_line_data(self, line, line_pol, index, region_id, inferencer, htr_threshold):
131
  line_data = {"id": f"line_{region_id}_{index}", "boundary": line_pol}
132
 
133
+ # temporary solution..
134
+ if self.htr_tool_transcriber_model_dropdown == "Riksarkivet/satrn_htr":
135
+ transcribed_text, htr_score = inferencer.transcribe(line)
136
+ else:
137
+ transcribed_text, htr_score = inferencer.transcribe_different_model(
138
+ line, self.htr_tool_transcriber_model_dropdown
139
+ )
140
+
141
  line_data["unicode"] = self.xml_helper.escape_xml_chars(transcribed_text)
142
  line_data["pred_score"] = round(htr_score, 4)
143
 
tabs/htr_tool.py CHANGED
@@ -4,7 +4,16 @@ import gradio as gr
4
 
5
  from helper.examples.examples import DemoImages
6
  from helper.utils import TrafficDataHandler
7
- from src.htr_pipeline.gradio_backend import FastTrack, SingletonModelLoader
 
 
 
 
 
 
 
 
 
8
 
9
  model_loader = SingletonModelLoader()
10
  fast_track = FastTrack(model_loader)
@@ -55,17 +64,14 @@ with gr.Blocks() as htr_tool_tab:
55
  )
56
 
57
  with gr.Tab("Compare") as tab_model_compare_selector:
58
- with gr.Box():
59
- gr.Markdown(
60
- """
61
- **Work in progress**
62
-
63
- Compare different runs with uploaded Ground Truth and calculate CER. You will also be able to upload output format files
64
-
65
- """
66
- )
67
-
68
  calc_cer_button_fast = gr.Button("Calculate CER", variant="primary", visible=True)
 
 
 
 
 
69
 
70
  with gr.Column(scale=4):
71
  with gr.Box():
@@ -142,7 +148,11 @@ with gr.Blocks() as htr_tool_tab:
142
 
143
  with gr.Row():
144
  htr_tool_transcriber_model_dropdown = gr.Dropdown(
145
- choices=["Riksarkivet/satrn_htr", "microsoft/trocr-base-handwritten"],
 
 
 
 
146
  value="Riksarkivet/satrn_htr",
147
  label="Text recognition models",
148
  info="More models will be added",
@@ -167,50 +177,62 @@ with gr.Blocks() as htr_tool_tab:
167
  )
168
 
169
  with gr.Column(visible=False) as model_compare_selector:
170
- gr.Markdown("**Work in progress:**")
171
  with gr.Row():
172
- gr.Radio(
173
- choices=["Compare Page XML", "Compare different runs"],
174
- value="Compare Page XML",
175
- info="Compare different runs from HTRFLOW or with external runs.",
176
- )
177
  with gr.Row():
178
- gr.UploadButton(label="Run A")
179
-
180
- gr.UploadButton(label="Run B")
 
 
 
 
 
 
 
181
 
182
- gr.UploadButton(label="Ground Truth")
 
 
 
 
 
 
 
 
 
183
 
184
- with gr.Row():
185
- gr.HighlightedText(
186
- label="Text diff runs",
 
 
 
 
 
 
 
 
 
 
187
  combine_adjacent=True,
188
  show_legend=True,
189
  color_map={"+": "red", "-": "green"},
190
  )
191
-
192
- with gr.Row():
193
- gr.HighlightedText(
194
- label="Text diff ground truth",
195
  combine_adjacent=True,
196
  show_legend=True,
197
  color_map={"+": "red", "-": "green"},
198
  )
199
 
200
- with gr.Row():
201
- with gr.Column(scale=1):
202
- with gr.Row(equal_height=False):
203
- cer_output_fast = gr.Textbox(label="CER:")
204
- with gr.Column(scale=2):
205
- pass
206
-
207
  xml_rendered_placeholder_for_api = gr.Textbox(placeholder="XML", visible=False)
208
 
209
  htr_event_click_event = htr_pipeline_button.click(
210
  fast_track.segment_to_xml,
211
- inputs=[fast_track_input_region_image, radio_file_input],
212
  outputs=[fast_file_downlod, fast_file_downlod],
213
- queue=False,
214
  api_name=False,
215
  )
216
 
@@ -222,44 +244,21 @@ with gr.Blocks() as htr_tool_tab:
222
  api_name="run_htr_pipeline",
223
  )
224
 
225
- def dummy_update_htr_tool_transcriber_model_dropdown(htr_tool_transcriber_model_dropdown):
226
- return gr.update(value="Riksarkivet/satrn_htr")
227
-
228
- htr_tool_transcriber_model_dropdown.change(
229
- fn=dummy_update_htr_tool_transcriber_model_dropdown,
230
- inputs=htr_tool_transcriber_model_dropdown,
231
- outputs=htr_tool_transcriber_model_dropdown,
232
- queue=False,
233
- api_name=False,
234
- )
235
-
236
- def update_selected_tab_output_and_setting():
237
- return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
238
-
239
- def update_selected_tab_image_viewer():
240
- return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
241
-
242
- def update_selected_tab_model_compare():
243
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
244
-
245
  tab_output_and_setting_selector.select(
246
  fn=update_selected_tab_output_and_setting,
247
  outputs=[output_and_setting_tab, image_viewer_tab, model_compare_selector],
248
- queue=False,
249
  api_name=False,
250
  )
251
 
252
  tab_image_viewer_selector.select(
253
  fn=update_selected_tab_image_viewer,
254
  outputs=[output_and_setting_tab, image_viewer_tab, model_compare_selector],
255
- queue=False,
256
  api_name=False,
257
  )
258
 
259
  tab_model_compare_selector.select(
260
  fn=update_selected_tab_model_compare,
261
  outputs=[output_and_setting_tab, image_viewer_tab, model_compare_selector],
262
- queue=False,
263
  api_name=False,
264
  )
265
 
@@ -273,7 +272,6 @@ with gr.Blocks() as htr_tool_tab:
273
  fn=stop_function,
274
  inputs=None,
275
  outputs=None,
276
- queue=False,
277
  api_name=False,
278
  # cancels=[htr_event_click_event],
279
  )
@@ -282,7 +280,6 @@ with gr.Blocks() as htr_tool_tab:
282
  fn=fast_track.visualize_image_viewer,
283
  inputs=fast_track_input_region_image,
284
  outputs=[fast_track_output_image, text_polygon_dict],
285
- queue=False,
286
  api_name=False,
287
  )
288
 
@@ -290,7 +287,32 @@ with gr.Blocks() as htr_tool_tab:
290
  fast_track.get_text_from_coords,
291
  inputs=text_polygon_dict,
292
  outputs=selection_text_from_image_viewer,
293
- queue=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  api_name=False,
295
  )
296
 
 
4
 
5
  from helper.examples.examples import DemoImages
6
  from helper.utils import TrafficDataHandler
7
+ from src.htr_pipeline.gradio_backend import (
8
+ FastTrack,
9
+ SingletonModelLoader,
10
+ compare_diff_runs_highlight,
11
+ compute_cer_a_and_b_with_gt,
12
+ update_selected_tab_image_viewer,
13
+ update_selected_tab_model_compare,
14
+ update_selected_tab_output_and_setting,
15
+ upload_file,
16
+ )
17
 
18
  model_loader = SingletonModelLoader()
19
  fast_track = FastTrack(model_loader)
 
64
  )
65
 
66
  with gr.Tab("Compare") as tab_model_compare_selector:
67
+ with gr.Row():
68
+ diff_runs_button = gr.Button("Compare runs", variant="primary", visible=True)
 
 
 
 
 
 
 
 
69
  calc_cer_button_fast = gr.Button("Calculate CER", variant="primary", visible=True)
70
+ with gr.Row():
71
+ cer_output_fast = gr.Textbox(
72
+ label="Character Error Rate:",
73
+ info="The percentage of characters that have been transcribed incorrectly",
74
+ )
75
 
76
  with gr.Column(scale=4):
77
  with gr.Box():
 
148
 
149
  with gr.Row():
150
  htr_tool_transcriber_model_dropdown = gr.Dropdown(
151
+ choices=[
152
+ "Riksarkivet/satrn_htr",
153
+ "microsoft/trocr-base-handwritten",
154
+ "pstroe/bullinger-general-model",
155
+ ],
156
  value="Riksarkivet/satrn_htr",
157
  label="Text recognition models",
158
  info="More models will be added",
 
177
  )
178
 
179
  with gr.Column(visible=False) as model_compare_selector:
 
180
  with gr.Row():
181
+ gr.Markdown("Compare different runs (Page XML output) with Ground Truth (GT)")
 
 
 
 
182
  with gr.Row():
183
+ with gr.Group():
184
+ upload_button_run_a = gr.UploadButton("A", file_types=[".xml"], file_count="single")
185
+ file_input_xml_run_a = gr.File(
186
+ label=None,
187
+ file_count="single",
188
+ height=100,
189
+ elem_id="download_file",
190
+ interactive=False,
191
+ visible=False,
192
+ )
193
 
194
+ with gr.Group():
195
+ upload_button_run_b = gr.UploadButton("B", file_types=[".xml"], file_count="single")
196
+ file_input_xml_run_b = gr.File(
197
+ label=None,
198
+ file_count="single",
199
+ height=100,
200
+ elem_id="download_file",
201
+ interactive=False,
202
+ visible=False,
203
+ )
204
 
205
+ with gr.Group():
206
+ upload_button_run_gt = gr.UploadButton("GT", file_types=[".xml"], file_count="single")
207
+ file_input_xml_run_gt = gr.File(
208
+ label=None,
209
+ file_count="single",
210
+ height=100,
211
+ elem_id="download_file",
212
+ interactive=False,
213
+ visible=False,
214
+ )
215
+ with gr.Tab("Comparing run A with B"):
216
+ text_diff_runs = gr.HighlightedText(
217
+ label="A with B",
218
  combine_adjacent=True,
219
  show_legend=True,
220
  color_map={"+": "red", "-": "green"},
221
  )
222
+ with gr.Tab("Compare run A with Ground Truth"):
223
+ text_diff_gt = gr.HighlightedText(
224
+ label="A with GT",
 
225
  combine_adjacent=True,
226
  show_legend=True,
227
  color_map={"+": "red", "-": "green"},
228
  )
229
 
 
 
 
 
 
 
 
230
  xml_rendered_placeholder_for_api = gr.Textbox(placeholder="XML", visible=False)
231
 
232
  htr_event_click_event = htr_pipeline_button.click(
233
  fast_track.segment_to_xml,
234
+ inputs=[fast_track_input_region_image, radio_file_input, htr_tool_transcriber_model_dropdown],
235
  outputs=[fast_file_downlod, fast_file_downlod],
 
236
  api_name=False,
237
  )
238
 
 
244
  api_name="run_htr_pipeline",
245
  )
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  tab_output_and_setting_selector.select(
248
  fn=update_selected_tab_output_and_setting,
249
  outputs=[output_and_setting_tab, image_viewer_tab, model_compare_selector],
 
250
  api_name=False,
251
  )
252
 
253
  tab_image_viewer_selector.select(
254
  fn=update_selected_tab_image_viewer,
255
  outputs=[output_and_setting_tab, image_viewer_tab, model_compare_selector],
 
256
  api_name=False,
257
  )
258
 
259
  tab_model_compare_selector.select(
260
  fn=update_selected_tab_model_compare,
261
  outputs=[output_and_setting_tab, image_viewer_tab, model_compare_selector],
 
262
  api_name=False,
263
  )
264
 
 
272
  fn=stop_function,
273
  inputs=None,
274
  outputs=None,
 
275
  api_name=False,
276
  # cancels=[htr_event_click_event],
277
  )
 
280
  fn=fast_track.visualize_image_viewer,
281
  inputs=fast_track_input_region_image,
282
  outputs=[fast_track_output_image, text_polygon_dict],
 
283
  api_name=False,
284
  )
285
 
 
287
  fast_track.get_text_from_coords,
288
  inputs=text_polygon_dict,
289
  outputs=selection_text_from_image_viewer,
290
+ api_name=False,
291
+ )
292
+
293
+ upload_button_run_a.upload(
294
+ upload_file, inputs=upload_button_run_a, outputs=[file_input_xml_run_a, file_input_xml_run_a], api_name=False
295
+ )
296
+
297
+ upload_button_run_b.upload(
298
+ upload_file, inputs=upload_button_run_b, outputs=[file_input_xml_run_b, file_input_xml_run_b], api_name=False
299
+ )
300
+
301
+ upload_button_run_gt.upload(
302
+ upload_file, inputs=upload_button_run_gt, outputs=[file_input_xml_run_gt, file_input_xml_run_gt], api_name=False
303
+ )
304
+
305
+ diff_runs_button.click(
306
+ fn=compare_diff_runs_highlight,
307
+ inputs=[file_input_xml_run_a, file_input_xml_run_b, file_input_xml_run_gt],
308
+ outputs=[text_diff_runs, text_diff_gt],
309
+ api_name=False,
310
+ )
311
+
312
+ calc_cer_button_fast.click(
313
+ fn=compute_cer_a_and_b_with_gt,
314
+ inputs=[file_input_xml_run_a, file_input_xml_run_b, file_input_xml_run_gt],
315
+ outputs=cer_output_fast,
316
  api_name=False,
317
  )
318
 
tabs/overview_tab.py CHANGED
@@ -4,7 +4,7 @@ from helper.text.text_overview import TextOverview
4
 
5
  with gr.Blocks() as overview:
6
  with gr.Tabs():
7
- with gr.Tab("HTRFLOW"):
8
  with gr.Row():
9
  with gr.Column():
10
  gr.Markdown(TextOverview.htrflow_col1)
@@ -56,6 +56,8 @@ with gr.Blocks() as overview:
56
  with gr.Row():
57
  with gr.Column():
58
  gr.Markdown(TextOverview.changelog)
 
 
59
  with gr.Column():
60
  gr.Markdown(TextOverview.roadmap)
61
 
 
4
 
5
  with gr.Blocks() as overview:
6
  with gr.Tabs():
7
+ with gr.Tab("About"):
8
  with gr.Row():
9
  with gr.Column():
10
  gr.Markdown(TextOverview.htrflow_col1)
 
56
  with gr.Row():
57
  with gr.Column():
58
  gr.Markdown(TextOverview.changelog)
59
+ with gr.Accordion("Previous changes", open=False):
60
+ gr.Markdown(TextOverview.old_changelog)
61
  with gr.Column():
62
  gr.Markdown(TextOverview.roadmap)
63
 
tabs/stepwise_htr_tool.py CHANGED
@@ -287,24 +287,39 @@ with gr.Blocks() as stepwise_htr_tool_tab:
287
  else:
288
  return "Ground truth not provided"
289
 
290
- calc_cer_button.click(compute_cer, inputs=[dataframe_text_index, gt_text_index], outputs=cer_output)
 
 
 
 
 
291
 
292
- calc_cer_button.click(diff_texts, inputs=[dataframe_text_index, gt_text_index], outputs=[diff_token_output])
 
 
 
 
 
293
 
294
  region_segment_button.click(
295
  custom_track.region_segment,
296
  inputs=[input_region_image, reg_pred_score_threshold_slider, reg_containments_threshold_slider],
297
  outputs=[output_region_image, regions_cropped_gallery, image_placeholder_lines, control_line_segment],
 
298
  )
299
 
300
  regions_cropped_gallery.select(
301
- custom_track.get_select_index_image, regions_cropped_gallery, input_region_from_gallery
 
 
 
302
  )
303
 
304
  transcribed_text_df_finish.select(
305
  fn=custom_track.get_select_index_df,
306
  inputs=[transcribed_text_df_finish, mapping_dict],
307
  outputs=[gallery_inputs_lines_to_transcribe, dataframe_text_index],
 
308
  )
309
 
310
  line_segment_button.click(
@@ -322,9 +337,14 @@ with gr.Blocks() as stepwise_htr_tool_tab:
322
  image_placeholder_htr,
323
  control_htr,
324
  ],
 
325
  )
326
 
327
- copy_textarea.click(fn=None, _js="""document.querySelector("#textarea_stepwise_3 > label > button").click()""")
 
 
 
 
328
 
329
  transcribe_button.click(
330
  custom_track.transcribe_text,
@@ -337,6 +357,7 @@ with gr.Blocks() as stepwise_htr_tool_tab:
337
  control_results_transcribe,
338
  image_placeholder_explore_results,
339
  ],
 
340
  )
341
 
342
  clear_button.click(
@@ -377,6 +398,7 @@ with gr.Blocks() as stepwise_htr_tool_tab:
377
  image_placeholder_explore_results,
378
  image_placeholder_lines,
379
  ],
 
380
  )
381
 
382
  SECRET_KEY = os.environ.get("AM_I_IN_A_DOCKER_CONTAINER", False)
 
287
  else:
288
  return "Ground truth not provided"
289
 
290
+ calc_cer_button.click(
291
+ compute_cer,
292
+ inputs=[dataframe_text_index, gt_text_index],
293
+ outputs=cer_output,
294
+ api_name=False,
295
+ )
296
 
297
+ calc_cer_button.click(
298
+ diff_texts,
299
+ inputs=[dataframe_text_index, gt_text_index],
300
+ outputs=[diff_token_output],
301
+ api_name=False,
302
+ )
303
 
304
  region_segment_button.click(
305
  custom_track.region_segment,
306
  inputs=[input_region_image, reg_pred_score_threshold_slider, reg_containments_threshold_slider],
307
  outputs=[output_region_image, regions_cropped_gallery, image_placeholder_lines, control_line_segment],
308
+ api_name=False,
309
  )
310
 
311
  regions_cropped_gallery.select(
312
+ custom_track.get_select_index_image,
313
+ regions_cropped_gallery,
314
+ input_region_from_gallery,
315
+ api_name=False,
316
  )
317
 
318
  transcribed_text_df_finish.select(
319
  fn=custom_track.get_select_index_df,
320
  inputs=[transcribed_text_df_finish, mapping_dict],
321
  outputs=[gallery_inputs_lines_to_transcribe, dataframe_text_index],
322
+ api_name=False,
323
  )
324
 
325
  line_segment_button.click(
 
337
  image_placeholder_htr,
338
  control_htr,
339
  ],
340
+ api_name=False,
341
  )
342
 
343
+ copy_textarea.click(
344
+ fn=None,
345
+ _js="""document.querySelector("#textarea_stepwise_3 > label > button").click()""",
346
+ api_name=False,
347
+ )
348
 
349
  transcribe_button.click(
350
  custom_track.transcribe_text,
 
357
  control_results_transcribe,
358
  image_placeholder_explore_results,
359
  ],
360
+ api_name=False,
361
  )
362
 
363
  clear_button.click(
 
398
  image_placeholder_explore_results,
399
  image_placeholder_lines,
400
  ],
401
+ api_name=False,
402
  )
403
 
404
  SECRET_KEY = os.environ.get("AM_I_IN_A_DOCKER_CONTAINER", False)