Spaces:
Sleeping
Sleeping
import os | |
import shutil | |
from difflib import Differ | |
import evaluate | |
import gradio as gr | |
from helper.examples.examples import DemoImages | |
from helper.utils import TrafficDataHandler | |
from src.htr_pipeline.gradio_backend import CustomTrack, SingletonModelLoader | |
# Pipeline backend: the loader is handed to CustomTrack, whose methods are the
# event handlers wired up further down. Both are built once, at import time.
model_loader = SingletonModelLoader()
custom_track = CustomTrack(model_loader)

# Example images for gr.Examples in tab 1, and the character-error-rate metric
# used by compute_cer in tab 4.
images_for_demo, cer_metric = DemoImages(), evaluate.load("cer")
with gr.Blocks() as stepwise_htr_tool_tab:
    with gr.Tabs():
        # ------------------------------------------------------------------
        # Step 1: segment a page image into regions.
        # ------------------------------------------------------------------
        with gr.Tab("1. Region segmentation"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Hidden helpers: the first is an output target of the Clear
                    # button, the second is an input slot for gr.Examples.
                    vis_data_folder_placeholder = gr.Markdown(visible=False)
                    name_files_placeholder = gr.Markdown(visible=False)

                    with gr.Group():
                        input_region_image = gr.Image(
                            label="Image to region segment",
                            tool="editor",
                            height=500,
                        )

                        with gr.Accordion("Settings", open=False):
                            with gr.Group():
                                reg_pred_score_threshold_slider = gr.Slider(
                                    minimum=0.4,
                                    maximum=1,
                                    value=0.5,
                                    step=0.05,
                                    label="P-threshold",
                                    info="Filter the confidence score for a prediction score to be considered",
                                )
                                reg_containments_threshold_slider = gr.Slider(
                                    minimum=0,
                                    maximum=1,
                                    value=0.5,
                                    step=0.05,
                                    label="C-threshold",
                                    info=(
                                        "The minimum required overlap or similarity\n"
                                        "for a detected region or object to be considered valid"
                                    ),
                                )
                                # NOTE(review): not an input of any event below, so the
                                # selected model is currently not passed to the backend.
                                region_segment_model_dropdown = gr.Dropdown(
                                    choices=["Riksarkivet/rtm_region"],
                                    value="Riksarkivet/rtm_region",
                                    label="Region segmentation model",
                                    info="More models will be added",
                                )

                    with gr.Row():
                        clear_button = gr.Button("Clear", variant="secondary", elem_id="clear_button")
                        region_segment_button = gr.Button(
                            "Run",
                            variant="primary",
                            elem_id="region_segment_button",
                        )
                        # Constant input for the optional
                        # TrafficDataHandler.store_metric_data listener at the end of this block.
                        region_segment_button_var = gr.State(value="region_segment_button")

                with gr.Column(scale=2):
                    with gr.Box():
                        with gr.Row():
                            with gr.Column(scale=2):
                                gr.Examples(
                                    examples=images_for_demo.examples_list,
                                    inputs=[name_files_placeholder, input_region_image],
                                    label="Example images",
                                    examples_per_page=5,
                                )
                            with gr.Column(scale=3):
                                output_region_image = gr.Image(label="Segmented regions", type="numpy", height=600)

        # ------------------------------------------------------------------
        # Step 2: segment a selected region into text lines.
        # ------------------------------------------------------------------
        with gr.Tab("2. Line segmentation"):
            # The visibility of this placeholder and of control_line_segment is
            # driven by region_segment_button and the Clear button.
            image_placeholder_lines = gr.Image(
                label="Segmented lines",
                interactive=False,  # must be a bool; the string "False" is truthy
                visible=True,
                height=600,
            )

            with gr.Row(visible=False) as control_line_segment:
                with gr.Column(scale=2):
                    with gr.Group():
                        with gr.Box():
                            regions_cropped_gallery = gr.Gallery(
                                label="Segmented regions",
                                elem_id="gallery",
                                columns=[2],
                                rows=[2],
                                height=450,
                                preview=True,
                                container=False,
                            )
                        # Filled by selecting an image in regions_cropped_gallery.
                        input_region_from_gallery = gr.Image(
                            label="Region segmentation to line segment",
                            interactive=False,
                            visible=False,
                            height=400,
                        )
                        with gr.Row():
                            with gr.Accordion("Settings", open=False):
                                with gr.Row():
                                    line_pred_score_threshold_slider = gr.Slider(
                                        minimum=0.3,
                                        maximum=1,
                                        value=0.4,
                                        step=0.05,
                                        label="Pred_score threshold",
                                        info="Filter the confidence score for a prediction score to be considered",
                                    )
                                    line_containments_threshold_slider = gr.Slider(
                                        minimum=0,
                                        maximum=1,
                                        value=0.5,
                                        step=0.05,
                                        label="Containments threshold",
                                        info=(
                                            "The minimum required overlap or similarity\n"
                                            "for a detected region or object to be considered valid"
                                        ),
                                    )
                                with gr.Row(equal_height=False):
                                    # NOTE(review): not an input of any event below.
                                    line_segment_model_dropdown = gr.Dropdown(
                                        choices=["Riksarkivet/rtmdet_lines"],
                                        value="Riksarkivet/rtmdet_lines",
                                        label="Line segment model",
                                        info="More models will be added",
                                    )
                        with gr.Row():
                            gr.Markdown(" ")  # layout spacer beside the Run button
                            line_segment_button = gr.Button("Run", variant="primary", scale=1)

                with gr.Column(scale=3):
                    output_line_from_region = gr.Image(
                        label="Segmented lines", type="numpy", interactive=False, height=600
                    )

        # ------------------------------------------------------------------
        # Step 3: transcribe the segmented lines.
        # ------------------------------------------------------------------
        with gr.Tab("3. Text recognition"):
            image_placeholder_htr = gr.Image(
                label="Transcribed lines",
                interactive=False,
                visible=True,
                height=600,
            )

            with gr.Row(visible=False) as control_htr:
                # Session state written by line_segment_button and read by transcribe_button.
                inputs_lines_to_transcribe = gr.State()

                with gr.Column(scale=2):
                    with gr.Group():
                        image_inputs_lines_to_transcribe = gr.Image(
                            label="Transcribed lines",
                            type="numpy",
                            interactive=False,
                            visible=False,
                            height=470,
                        )
                        with gr.Row():
                            with gr.Accordion("Settings", open=False):
                                # NOTE(review): not an input of transcribe_button, so the
                                # selected model is currently not passed to the backend.
                                transcriber_model = gr.Dropdown(
                                    choices=["Riksarkivet/satrn_htr", "microsoft/trocr-base-handwritten"],
                                    value="Riksarkivet/satrn_htr",
                                    label="Text recognition model",
                                    info="More models will be added",
                                )
                                # NOTE(review): this slider is not bound to a name and is
                                # not an input of any event, so it has no effect yet.
                                gr.Slider(
                                    value=0.6,
                                    minimum=0.5,
                                    maximum=1,
                                    label="HTR threshold",
                                    info="Prediction score threshold for transcribed lines",
                                    scale=1,
                                )
                        with gr.Row():
                            copy_textarea = gr.Button("Copy text", variant="secondary", visible=True, scale=1)
                            transcribe_button = gr.Button("Run", variant="primary", visible=True, scale=1)

                with gr.Column(scale=3):
                    with gr.Row():
                        transcribed_text = gr.Textbox(
                            label="Transcribed text",
                            info="Transcribed text is being streamed back from the Text recognition model",
                            lines=26,
                            value="",
                            show_copy_button=True,
                            elem_id="textarea_stepwise_3",
                        )

        # ------------------------------------------------------------------
        # Step 4: inspect individual lines and score them against ground truth.
        # ------------------------------------------------------------------
        with gr.Tab("4. Explore results"):
            image_placeholder_explore_results = gr.Image(
                label="Cropped transcribed lines",
                interactive=False,
                visible=True,
                height=600,
            )

            with gr.Row(visible=False, equal_height=False) as control_results_transcribe:
                with gr.Column(scale=1, visible=True):
                    with gr.Group():
                        with gr.Box():
                            # Session state written by line_segment_button.
                            temp_gallery_input = gr.State()
                            gallery_inputs_lines_to_transcribe = gr.Gallery(
                                label="Cropped transcribed lines",
                                elem_id="gallery_lines",
                                columns=[3],
                                rows=[3],
                                height=150,
                                preview=True,
                                container=False,
                            )
                        with gr.Row():
                            dataframe_text_index = gr.Textbox(
                                label="Text from DataFrame selection",
                                placeholder="Select row from the DataFrame.",
                                interactive=False,
                            )
                        with gr.Row():
                            gt_text_index = gr.Textbox(
                                label="Ground Truth",
                                placeholder="Provide the ground truth, if available.",
                                interactive=True,
                            )
                        with gr.Row():
                            diff_token_output = gr.HighlightedText(
                                label="Text diff",
                                combine_adjacent=True,
                                show_legend=True,
                                color_map={"+": "red", "-": "green"},
                            )
                        with gr.Row(equal_height=False):
                            cer_output = gr.Textbox(label="Character Error Rate")
                            gr.Markdown("")
                            calc_cer_button = gr.Button("Calculate CER", variant="primary", visible=True)

                with gr.Column(scale=1, visible=True):
                    # Session state written by transcribe_button and read on DataFrame selection.
                    mapping_dict = gr.State()
                    transcribed_text_df_finish = gr.Dataframe(
                        headers=["Transcribed text", "Prediction score"],
                        max_rows=14,
                        col_count=(2, "fixed"),
                        wrap=True,
                        interactive=False,
                        overflow_row_behaviour="paginate",
                        height=600,
                    )

    # ----------------------------------------------------------------------
    # Event handlers
    # ----------------------------------------------------------------------

    def diff_texts(text1, text2):
        """Character-level diff of *text1* against *text2* for gr.HighlightedText.

        Returns a list of ``(char, tag)`` pairs where ``tag`` is ``"-"`` for a
        character only in *text1*, ``"+"`` for one only in *text2*, and
        ``None`` for a character both share.
        """
        d = Differ()
        # Differ.compare on two strings compares them character by character;
        # each token is "<tag><space><char>".
        return [(token[2:], token[0] if token[0] != " " else None) for token in d.compare(text1, text2)]

    def compute_cer(dataframe_text_index, gt_text_index):
        """Return the CER of the selected transcription against the ground truth.

        Returns the score rounded to 4 decimals, or the message
        "Ground truth not provided" when the ground-truth box is empty or blank.
        """
        if gt_text_index is None or not gt_text_index.strip():
            return "Ground truth not provided"
        return round(cer_metric.compute(predictions=[dataframe_text_index], references=[gt_text_index]), 4)

    def clear_all_steps():
        """Remove ./vis_data (if present) and reset the stepwise UI.

        Returns one value per entry of the ``outputs`` list passed to
        ``clear_button.click`` below, in the same order. Fresh gr.update dicts
        are built on each call.
        """
        # NOTE(review): ./vis_data presumably holds artifacts of earlier runs — confirm.
        try:
            shutil.rmtree("./vis_data")
        except FileNotFoundError:
            pass  # nothing to remove
        return (
            None,  # vis_data_folder_placeholder
            None,  # input_region_image
            None,  # regions_cropped_gallery
            None,  # input_region_from_gallery
            gr.update(visible=False),  # control_line_segment
            None,  # output_line_from_region
            None,  # inputs_lines_to_transcribe
            None,  # transcribed_text
            gr.update(visible=False),  # control_htr
            gr.update(visible=True),  # image_placeholder_htr
            None,  # output_region_image
            gr.update(visible=False),  # image_inputs_lines_to_transcribe
            gr.update(visible=False),  # control_results_transcribe
            gr.update(visible=True),  # image_placeholder_explore_results
            gr.update(visible=True),  # image_placeholder_lines
        )

    calc_cer_button.click(compute_cer, inputs=[dataframe_text_index, gt_text_index], outputs=cer_output)
    calc_cer_button.click(diff_texts, inputs=[dataframe_text_index, gt_text_index], outputs=[diff_token_output])

    region_segment_button.click(
        custom_track.region_segment,
        inputs=[input_region_image, reg_pred_score_threshold_slider, reg_containments_threshold_slider],
        outputs=[output_region_image, regions_cropped_gallery, image_placeholder_lines, control_line_segment],
    )

    regions_cropped_gallery.select(
        custom_track.get_select_index_image, regions_cropped_gallery, input_region_from_gallery
    )

    transcribed_text_df_finish.select(
        fn=custom_track.get_select_index_df,
        inputs=[transcribed_text_df_finish, mapping_dict],
        outputs=[gallery_inputs_lines_to_transcribe, dataframe_text_index],
    )

    # NOTE(review): image_inputs_lines_to_transcribe is listed twice. This list
    # must match, entry for entry, the values returned by
    # custom_track.line_segment, so keep the two in sync if this changes.
    line_segment_button.click(
        custom_track.line_segment,
        inputs=[input_region_from_gallery, line_pred_score_threshold_slider, line_containments_threshold_slider],
        outputs=[
            output_line_from_region,
            image_inputs_lines_to_transcribe,
            inputs_lines_to_transcribe,
            gallery_inputs_lines_to_transcribe,
            temp_gallery_input,
            # Visibility toggles
            transcribe_button,
            image_inputs_lines_to_transcribe,
            image_placeholder_htr,
            control_htr,
        ],
    )

    # Client-side only: clicks the Textbox's built-in copy button.
    copy_textarea.click(fn=None, _js="""document.querySelector("#textarea_stepwise_3 > label > button").click()""")

    transcribe_button.click(
        custom_track.transcribe_text,
        inputs=[inputs_lines_to_transcribe],
        outputs=[
            transcribed_text,
            transcribed_text_df_finish,
            mapping_dict,
            # Visibility toggles
            control_results_transcribe,
            image_placeholder_explore_results,
        ],
    )

    clear_button.click(
        clear_all_steps,
        inputs=[],
        outputs=[
            vis_data_folder_placeholder,
            input_region_image,
            regions_cropped_gallery,
            input_region_from_gallery,
            control_line_segment,
            output_line_from_region,
            inputs_lines_to_transcribe,
            transcribed_text,
            control_htr,
            image_placeholder_htr,
            output_region_image,
            image_inputs_lines_to_transcribe,
            control_results_transcribe,
            image_placeholder_explore_results,
            image_placeholder_lines,
        ],
    )

    # Register the TrafficDataHandler listener only when this environment
    # variable is set. Environment values are strings, so any non-empty value
    # (even "false") enables it. The name is kept for compatibility; it is not
    # a secret.
    SECRET_KEY = os.environ.get("AM_I_IN_A_DOCKER_CONTAINER", False)
    if SECRET_KEY:
        region_segment_button.click(fn=TrafficDataHandler.store_metric_data, inputs=region_segment_button_var)