# htr_demo/tabs/htr_tool.py
import os

import gradio as gr

from helper.examples.examples import DemoImages
from helper.utils import TrafficDataHandler
from src.htr_pipeline.gradio_backend import (
    FastTrack,
    SingletonModelLoader,
    compare_diff_runs_highlight,
    compute_cer_a_and_b_with_gt,
    update_selected_tab_image_viewer,
    update_selected_tab_model_compare,
    update_selected_tab_output_and_setting,
    upload_file,
)

model_loader = SingletonModelLoader()
fast_track = FastTrack(model_loader)
images_for_demo = DemoImages()

terminate = False


with gr.Blocks() as htr_tool_tab:
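    # Layout: a two-column row. The left column holds the input image and the
    # HTRFLOW / Visualize / Compare control tabs; the wider right column switches
    # between the settings panel, the image viewer and the run-comparison view.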
    with gr.Row(equal_height=True):
        with gr.Column(scale=2):
            with gr.Row():
                fast_track_input_region_image = gr.Image(
                    label="Image to run HTR on", type="numpy", tool="editor", elem_id="image_upload", height=395
                )

            with gr.Row():
                with gr.Tab("HTRFLOW") as tab_output_and_setting_selector:
                    with gr.Row():
                        stop_htr_button = gr.Button(
                            value="Stop run",
                            variant="stop",
                        )

                        htr_pipeline_button = gr.Button(
                            "Run ",
                            variant="primary",
                            visible=True,
                            elem_id="run_pipeline_button",
                        )

                        htr_pipeline_button_var = gr.State(value="htr_pipeline_button")
                        htr_pipeline_button_api = gr.Button("Run pipeline", variant="primary", visible=False, scale=1)

                    fast_file_downlod = gr.File(
                        label="Download output file", visible=True, scale=1, height=100, elem_id="download_file"
                    )

                with gr.Tab("Visualize") as tab_image_viewer_selector:
                    with gr.Row():
                        gr.Markdown("")
                        run_image_visualizer_button = gr.Button(
                            value="Visualize results", variant="primary", interactive=True
                        )

                    selection_text_from_image_viewer = gr.Textbox(
                        interactive=False, label="Text Selector", info="Select a line on Image Viewer to return text"
                    )

                with gr.Tab("Compare") as tab_model_compare_selector:
                    with gr.Row():
                        diff_runs_button = gr.Button("Compare runs", variant="primary", visible=True)
                        calc_cer_button_fast = gr.Button("Calculate CER", variant="primary", visible=True)

                    with gr.Row():
                        cer_output_fast = gr.Textbox(
                            label="Character Error Rate:",
                            info="The percentage of characters that have been transcribed incorrectly",
                        )
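        # Right column: three alternative panels. Only one is visible at a time;
        # the tab selectors on the left toggle their visibility.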
        with gr.Column(scale=4):
            with gr.Box():
                with gr.Row(visible=True) as output_and_setting_tab:
                    with gr.Column(scale=2):
                        fast_name_files_placeholder = gr.Markdown(visible=False)

                        gr.Examples(
                            examples=images_for_demo.examples_list,
                            inputs=[fast_name_files_placeholder, fast_track_input_region_image],
                            label="Example images",
                            examples_per_page=5,
                        )

                        gr.Markdown(" ")

                    with gr.Column(scale=3):
                        with gr.Group():
                            gr.Markdown("   ⚙️ Settings ")

                            with gr.Row():
                                radio_file_input = gr.CheckboxGroup(
                                    choices=["Txt", "Page XML"],
                                    value=["Txt", "Page XML"],
                                    label="Output file extension",
                                    info="JSON and ALTO-XML will be added",
                                    scale=1,
                                )

                            with gr.Row():
                                gr.Checkbox(
                                    value=True,
                                    label="Binarize image",
                                    info="Binarize image to reduce background noise",
                                )
                                gr.Checkbox(
                                    value=True,
                                    label="Output prediction threshold",
                                    info="Output XML with prediction score",
                                )

                            with gr.Accordion("Advanced settings", open=False):
                                with gr.Group():
                                    with gr.Row():
                                        htr_tool_region_segment_model_dropdown = gr.Dropdown(
                                            choices=["Riksarkivet/rtmdet_region"],
                                            value="Riksarkivet/rtmdet_region",
                                            label="Region segmentation models",
                                            info="More models will be added",
                                        )
                                        gr.Slider(
                                            minimum=0.4,
                                            maximum=1,
                                            value=0.5,
                                            step=0.05,
                                            label="P-threshold",
                                            info="Filter confidence score for a prediction score to be considered",
                                        )

                                    with gr.Row():
                                        htr_tool_line_segment_model_dropdown = gr.Dropdown(
                                            choices=["Riksarkivet/rtmdet_lines"],
                                            value="Riksarkivet/rtmdet_lines",
                                            label="Line segmentation models",
                                            info="More models will be added",
                                        )
                                        gr.Slider(
                                            minimum=0.4,
                                            maximum=1,
                                            value=0.5,
                                            step=0.05,
                                            label="P-threshold",
                                            info="Filter confidence score for a prediction score to be considered",
                                        )

                                    with gr.Row():
                                        htr_tool_transcriber_model_dropdown = gr.Dropdown(
                                            choices=[
                                                "Riksarkivet/satrn_htr",
                                                "microsoft/trocr-base-handwritten",
                                                "pstroe/bullinger-general-model",
                                            ],
                                            value="Riksarkivet/satrn_htr",
                                            label="Text recognition models",
                                            info="More models will be added",
                                        )
                                        gr.Slider(
                                            value=0.6,
                                            minimum=0.5,
                                            maximum=1,
                                            label="HTR threshold",
                                            info="Prediction score threshold for transcribed lines",
                                            scale=1,
                                        )

                                    with gr.Row():
                                        gr.Markdown("   More settings will be added")
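                # Image viewer panel, shown when the "Visualize" tab is selected.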
                with gr.Row(visible=False) as image_viewer_tab:
                    text_polygon_dict = gr.Variable()

                    fast_track_output_image = gr.Image(
                        label="Image Viewer", type="numpy", height=600, interactive=False
                    )
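                # Comparison panel, shown when the "Compare" tab is selected:
                # upload two runs (A, B) and a ground truth (GT) as Page XML.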
                with gr.Column(visible=False) as model_compare_selector:
                    with gr.Row():
                        gr.Markdown("Compare different runs (Page XML output) with Ground Truth (GT)")

                    with gr.Row():
                        with gr.Group():
                            upload_button_run_a = gr.UploadButton("A", file_types=[".xml"], file_count="single")
                            file_input_xml_run_a = gr.File(
                                label=None,
                                file_count="single",
                                height=100,
                                elem_id="download_file",
                                interactive=False,
                                visible=False,
                            )

                        with gr.Group():
                            upload_button_run_b = gr.UploadButton("B", file_types=[".xml"], file_count="single")
                            file_input_xml_run_b = gr.File(
                                label=None,
                                file_count="single",
                                height=100,
                                elem_id="download_file",
                                interactive=False,
                                visible=False,
                            )

                        with gr.Group():
                            upload_button_run_gt = gr.UploadButton("GT", file_types=[".xml"], file_count="single")
                            file_input_xml_run_gt = gr.File(
                                label=None,
                                file_count="single",
                                height=100,
                                elem_id="download_file",
                                interactive=False,
                                visible=False,
                            )

                    with gr.Tab("Comparing run A with B"):
                        text_diff_runs = gr.HighlightedText(
                            label="A with B",
                            combine_adjacent=True,
                            show_legend=True,
                            color_map={"+": "red", "-": "green"},
                        )

                    with gr.Tab("Compare run A with Ground Truth"):
                        text_diff_gt = gr.HighlightedText(
                            label="A with GT",
                            combine_adjacent=True,
                            show_legend=True,
                            color_map={"+": "red", "-": "green"},
                        )

    xml_rendered_placeholder_for_api = gr.Textbox(placeholder="XML", visible=False)
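    # Run the HTR pipeline. The visible button drives the UI; the hidden button
    # exposes the pipeline as the "run_htr_pipeline" API endpoint.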
    htr_event_click_event = htr_pipeline_button.click(
        fast_track.segment_to_xml,
        inputs=[fast_track_input_region_image, radio_file_input, htr_tool_transcriber_model_dropdown],
        outputs=[fast_file_downlod, fast_file_downlod],
        api_name=False,
    )

    htr_pipeline_button_api.click(
        fast_track.segment_to_xml_api,
        inputs=[fast_track_input_region_image],
        outputs=[xml_rendered_placeholder_for_api],
        queue=False,
        api_name="run_htr_pipeline",
    )
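    # Sketch of a client call against the "run_htr_pipeline" endpoint above
    # (the Space name and image path are assumptions, not taken from this file):
    #
    #   from gradio_client import Client
    #
    #   client = Client("Riksarkivet/htr_demo")  # or the URL where this app is served
    #   page_xml = client.predict("page_image.jpg", api_name="/run_htr_pipeline")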
    tab_output_and_setting_selector.select(
        fn=update_selected_tab_output_and_setting,
        outputs=[output_and_setting_tab, image_viewer_tab, model_compare_selector],
        api_name=False,
    )

    tab_image_viewer_selector.select(
        fn=update_selected_tab_image_viewer,
        outputs=[output_and_setting_tab, image_viewer_tab, model_compare_selector],
        api_name=False,
    )

    tab_model_compare_selector.select(
        fn=update_selected_tab_model_compare,
        outputs=[output_and_setting_tab, image_viewer_tab, model_compare_selector],
        api_name=False,
    )
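    # The tab selectors above switch panel visibility; the stop button below
    # signals the running pipeline to abort via the inferencer's terminate flag.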
    def stop_function():
        from src.htr_pipeline.utils import pipeline_inferencer

        pipeline_inferencer.terminate = True
        gr.Info("The HTR execution was halted")

    stop_htr_button.click(
        fn=stop_function,
        inputs=None,
        outputs=None,
        api_name=False,
        # cancels=[htr_event_click_event],
    )
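    # Image viewer events: render the segmented page, then look up the
    # transcription for a line selected in the viewer.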
    run_image_visualizer_button.click(
        fn=fast_track.visualize_image_viewer,
        inputs=fast_track_input_region_image,
        outputs=[fast_track_output_image, text_polygon_dict],
        api_name=False,
    )

    fast_track_output_image.select(
        fast_track.get_text_from_coords,
        inputs=text_polygon_dict,
        outputs=selection_text_from_image_viewer,
        api_name=False,
    )
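    # Comparison events: store the uploaded Page XML files, then diff the runs
    # or compute CER against the ground truth.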
    upload_button_run_a.upload(
        upload_file, inputs=upload_button_run_a, outputs=[file_input_xml_run_a, file_input_xml_run_a], api_name=False
    )

    upload_button_run_b.upload(
        upload_file, inputs=upload_button_run_b, outputs=[file_input_xml_run_b, file_input_xml_run_b], api_name=False
    )

    upload_button_run_gt.upload(
        upload_file, inputs=upload_button_run_gt, outputs=[file_input_xml_run_gt, file_input_xml_run_gt], api_name=False
    )

    diff_runs_button.click(
        fn=compare_diff_runs_highlight,
        inputs=[file_input_xml_run_a, file_input_xml_run_b, file_input_xml_run_gt],
        outputs=[text_diff_runs, text_diff_gt],
        api_name=False,
    )

    calc_cer_button_fast.click(
        fn=compute_cer_a_and_b_with_gt,
        inputs=[file_input_xml_run_a, file_input_xml_run_b, file_input_xml_run_gt],
        outputs=cer_output_fast,
        api_name=False,
    )
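    # Optional usage metrics: only wired up when a HUB_TOKEN secret is available.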
    SECRET_KEY = os.environ.get("HUB_TOKEN", False)

    if SECRET_KEY:
        htr_pipeline_button.click(
            fn=TrafficDataHandler.store_metric_data,
            inputs=htr_pipeline_button_var,
        )