# htr_demo/src/htr_pipeline/utils/pipeline_inferencer.py
# Commit c60ebd1 by Gabriel: "0.0.3 release with Trocr and compare support"
# (Hugging Face file-page residue: raw | history | blame | contribute | delete | "No virus" | 5.34 kB)
import gradio as gr
from tqdm import tqdm
from src.htr_pipeline.utils.process_segmask import SegMaskHelper
from src.htr_pipeline.utils.xml_helper import XMLHelper
# Module-level cancellation flag read by PipelineInferencer: its region loop and
# line loop break out as soon as this is True. Nothing in this module sets it to
# True or resets it to False afterwards; presumably a cancel handler elsewhere
# assigns it — TODO confirm.
terminate: bool = False
# TODO: Check why region processing is so slow to start. Is there an error when loading the model?
class PipelineInferencer:
    """Turn a page image into PAGE XML by chaining region detection, line detection and HTR.

    The heavy lifting is delegated to an ``inferencer`` object passed to
    :meth:`image_to_page_xml`. It must provide ``predict_regions``,
    ``predict_lines``, ``transcribe`` and ``transcribe_different_model``.
    The region and line loops stop early once the module-level ``terminate``
    flag is True.
    """

    # Transcriber id handled by ``inferencer.transcribe``; every other id is
    # routed to ``inferencer.transcribe_different_model``.
    DEFAULT_TRANSCRIBER_MODEL = "Riksarkivet/satrn_htr"

    # Extra margin above ``htr_threshold`` that a region's mean line score must
    # exceed for the region to be kept.
    REGION_SCORE_MARGIN = 0.1

    def __init__(self, process_seg_mask: SegMaskHelper, xml_helper: XMLHelper):
        """Store the helpers used for coordinate translation and XML rendering.

        Args:
            process_seg_mask: helper whose ``_translate_line_coords`` maps line
                polygons detected inside a region crop using that region's mask
                (presumably into page coordinates).
            xml_helper: helper that prepares the template data, escapes XML
                characters and renders the final document.
        """
        self.process_seg_mask = process_seg_mask
        self.xml_helper = xml_helper
        # Overwritten on every image_to_page_xml call. Initialised here so the
        # private helpers do not raise AttributeError when used directly.
        self.htr_tool_transcriber_model_dropdown = self.DEFAULT_TRANSCRIBER_MODEL

    def image_to_page_xml(
        self,
        image,
        htr_tool_transcriber_model_dropdown,
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        inferencer,
    ):
        """Run the full pipeline on ``image`` and render the result with ``xml_helper``.

        Args:
            image: page image passed to ``inferencer.predict_regions`` and
                ``xml_helper.prepare_template_data``.
            htr_tool_transcriber_model_dropdown: selected transcriber model id.
                ``DEFAULT_TRANSCRIBER_MODEL`` uses ``inferencer.transcribe``;
                any other id uses ``inferencer.transcribe_different_model``.
            pred_score_threshold_regions: score threshold for region detection.
            pred_score_threshold_lines: score threshold for line detection.
            containments_threshold: containment threshold passed to both detectors.
            inferencer: object that runs detection and transcription.

        Returns:
            Whatever ``xml_helper.render`` returns for the filled template
            (presumably the PAGE XML document as a string).
        """
        # Temporary solution for TrOCR: remember the selected transcriber so
        # _create_line_data can pick the matching transcription call.
        self.htr_tool_transcriber_model_dropdown = htr_tool_transcriber_model_dropdown
        template_data = self.xml_helper.prepare_template_data(self.xml_helper.xml_file_name, image)
        template_data["textRegions"] = self._process_regions(
            image, inferencer, pred_score_threshold_regions, pred_score_threshold_lines, containments_threshold
        )
        return self.xml_helper.render(template_data)

    def _process_regions(
        self,
        image,
        inferencer,
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        htr_threshold=0.6,
    ):
        """Detect text regions in ``image`` and build the data dict for each kept region.

        Returns:
            A list of region dicts (see ``_create_region_data``). Regions whose
            lines all fail ``htr_threshold``, or whose mean line score is too
            low, are dropped. The loop stops early if ``terminate`` becomes True.
        """
        _, regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered = inferencer.predict_regions(
            image,
            pred_score_threshold=pred_score_threshold_regions,
            containments_threshold=containments_threshold,
            visualize=False,
        )
        n_regions = len(regions_cropped_ordered)
        gr.Info(f"Found {n_regions} Regions to parse")

        region_data_list = []
        regions = zip(regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered)
        # enumerate(zip(...)) has no len(), so give tqdm the total explicitly;
        # otherwise it can only show a bare counter instead of a progress bar.
        for i, data in tqdm(enumerate(regions), total=n_regions):
            # Module-level cancellation flag; reading it needs no ``global``.
            if terminate:
                break
            region_data = self._create_region_data(
                data, i, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
            )
            if region_data:
                region_data_list.append(region_data)
        return region_data_list

    def _create_region_data(
        self, data, index, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
    ):
        """Transcribe the lines of one region and return its dict, or None if it is rejected.

        Args:
            data: ``(cropped_region, region_polygon, region_mask)`` tuple.
            index: region position, used to build the ``region_<index>`` id.

        Returns:
            A ``{"id", "boundary", "textLines"}`` dict. Returns None when no line
            passed ``htr_threshold``, or when the mean score over all transcribed
            lines does not exceed ``htr_threshold + REGION_SCORE_MARGIN``.
        """
        text_region, reg_pol, mask = data
        region_data = {"id": f"region_{index}", "boundary": reg_pol}
        text_lines, htr_scores = self._process_lines(
            text_region,
            inferencer,
            pred_score_threshold_lines,
            containments_threshold,
            mask,
            region_data["id"],
            htr_threshold,
        )
        if not text_lines:
            return None
        region_data["textLines"] = text_lines
        # The mean also covers lines rejected individually, so a region with
        # many poor lines is discarded even if a few lines passed.
        mean_htr_score = sum(htr_scores) / len(htr_scores) if htr_scores else 0
        return region_data if mean_htr_score > htr_threshold + self.REGION_SCORE_MARGIN else None

    def _process_lines(
        self, text_region, inferencer, pred_score_threshold, containments_threshold, mask, region_id, htr_threshold=0.6
    ):
        """Detect and transcribe the text lines of one cropped region.

        Returns:
            ``(text_lines, htr_scores)``: the line dicts that passed
            ``htr_threshold``, and the HTR score of every transcribed line
            (accepted or not). Both lists are empty when no line is detected.
        """
        _, lines_cropped_ordered, line_polygons_ordered = inferencer.predict_lines(
            text_region, pred_score_threshold, containments_threshold, visualize=False, custom_track=False
        )
        if not lines_cropped_ordered:
            # Empty lists (not None) so the result type matches the normal path.
            return [], []

        line_polygons_ordered_trans = self.process_seg_mask._translate_line_coords(mask, line_polygons_ordered)

        text_lines = []
        htr_scores = []
        # region_id has the form "region_<n>" when built by _create_region_data.
        id_number = region_id.split("_")[1]
        total_lines_len = len(lines_cropped_ordered)
        gr.Info(f" Region {id_number}, found {total_lines_len} lines to parse and transcribe.")

        for index, (line, line_pol) in enumerate(zip(lines_cropped_ordered, line_polygons_ordered_trans)):
            if terminate:
                break
            line_data, htr_score = self._create_line_data(line, line_pol, index, region_id, inferencer, htr_threshold)
            if line_data:
                text_lines.append(line_data)
            # Keep every score so the region mean also reflects rejected lines.
            htr_scores.append(htr_score)

            remaining_lines = total_lines_len - index - 1
            # Progress message every 10 lines, skipped when only a few lines remain.
            if (index + 1) % 10 == 0 and remaining_lines > 5:  # +1 because index starts at 0
                gr.Info(
                    f"Region {id_number}, parsed {index + 1} lines. Still {remaining_lines} lines left to transcribe."
                )
        return text_lines, htr_scores

    def _create_line_data(self, line, line_pol, index, region_id, inferencer, htr_threshold):
        """Transcribe one line crop.

        Returns:
            ``(line_data, htr_score)``. ``line_data`` is the
            ``{"id", "boundary", "unicode", "pred_score"}`` dict, or None when
            ``htr_score`` does not exceed ``htr_threshold``. The raw score is
            returned in both cases.
        """
        line_data = {"id": f"line_{region_id}_{index}", "boundary": line_pol}
        # Temporary solution: the default model is served by ``transcribe``;
        # other models (e.g. TrOCR) go through ``transcribe_different_model``.
        if self.htr_tool_transcriber_model_dropdown == self.DEFAULT_TRANSCRIBER_MODEL:
            transcribed_text, htr_score = inferencer.transcribe(line)
        else:
            transcribed_text, htr_score = inferencer.transcribe_different_model(
                line, self.htr_tool_transcriber_model_dropdown
            )
        line_data["unicode"] = self.xml_helper.escape_xml_chars(transcribed_text)
        line_data["pred_score"] = round(htr_score, 4)
        return (line_data if htr_score > htr_threshold else None), htr_score
if __name__ == "__main__":
    # Intentionally a no-op: this module only defines PipelineInferencer and
    # has no standalone entry point.
    pass