htr_demo / src /htr_pipeline /utils /process_xml.py
Gabriel's picture
added some changes for xml
4c85050
raw
history blame
No virus
5.28 kB
import os
import re
from datetime import datetime
import jinja2
from tqdm import tqdm
from src.htr_pipeline.inferencer import InferencerInterface
from src.htr_pipeline.utils.process_segmask import SegMaskHelper
class XMLHelper:
    """Build PAGE XML documents from HTR pipeline predictions.

    Text regions and text lines are obtained from an inferencer
    (``predict_regions`` / ``predict_lines`` / ``transcribe``), filtered by
    HTR score, and rendered with the Jinja2 template ``page_xml_2013.xml``.
    """

    def __init__(self):
        # Its ``_translate_line_coords`` is applied to line polygons together
        # with the region mask (presumably to move them from region-crop to
        # page coordinates — see ``_process_lines``).
        self.process_seg_mask = SegMaskHelper()

    def image_to_page_xml(
        self,
        image,
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        inferencer: "InferencerInterface",
        xml_file_name="page_xml.xml",
    ):
        """Detect regions/lines in ``image``, transcribe them and return PAGE XML.

        Args:
            image: Page image; ``image.shape[0]`` is used as the height and
                ``image.shape[1]`` as the width (assumed array-like, e.g. a
                numpy image — not checked here).
            pred_score_threshold_regions: Passed to ``inferencer.predict_regions``
                as ``pred_score_threshold``.
            pred_score_threshold_lines: Passed to ``inferencer.predict_lines``
                as ``pred_score_threshold``.
            containments_threshold: Passed to both region and line prediction.
            inferencer: Provides ``predict_regions``, ``predict_lines`` and
                ``transcribe``.
            xml_file_name: Written into the template context as ``imageFilename``.

        Returns:
            The rendered XML document as a string.
        """
        img_height, img_width = image.shape[0], image.shape[1]
        # NOTE(review): the *XML* file name is what ends up in ``imageFilename``;
        # confirm callers pass the image's file name here.
        template_data = self.prepare_template_data(xml_file_name, img_width, img_height)
        template_data["textRegions"] = self._process_regions(
            image,
            inferencer,
            pred_score_threshold_regions,
            pred_score_threshold_lines,
            containments_threshold,
        )
        return self._render_xml(template_data)

    def _transform_coords(self, input_string):
        """Rewrite every ``[x, y]`` pair in ``input_string`` as ``x,y``.

        Example: ``"[1, 2] [3,4]"`` becomes ``"1,2 3,4"``.
        """
        pattern = r"\[\s*([^\s,]+)\s*,\s*([^\s\]]+)\s*\]"
        replacement = r"\1,\2"
        return re.sub(pattern, replacement, input_string)

    def _render_xml(self, template_data):
        """Render ``template_data`` with ``page_xml_2013.xml`` and reformat coordinates."""
        # The CWD-relative path (original behaviour) is searched first; the
        # ``templates`` directory next to this module is a fallback so rendering
        # also works when the process is not started from the repository root.
        search_paths = [
            "./src/htr_pipeline/utils/templates",
            os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates"),
        ]
        template_loader = jinja2.FileSystemLoader(searchpath=search_paths)
        # Autoescaping stays off (Jinja2 default): line text is already escaped
        # by ``_escape_xml_chars``, so enabling it would double-escape.
        template_env = jinja2.Environment(loader=template_loader, trim_blocks=True)
        template = template_env.get_template("page_xml_2013.xml")
        rendered_xml = template.render(template_data)
        # NOTE(review): this rewrite runs over the whole document, so a
        # transcription containing e.g. "[1, 2]" is altered too; formatting the
        # coordinates before rendering would avoid that.
        return self._transform_coords(rendered_xml)

    def prepare_template_data(self, img_file_name, img_width, img_height):
        """Return the base template context for one document.

        ``created`` is the current local time formatted as
        ``"%Y-%m-%d, %H:%M:%S"``; ``textRegions`` starts empty and is filled
        in by the caller.
        """
        # NOTE(review): if the template writes ``created`` into PAGE <Created>,
        # the schema expects an xsd:dateTime (e.g. 2024-01-31T12:00:00); confirm
        # before changing this format, as downstream consumers may rely on it.
        date_time = datetime.now().strftime("%Y-%m-%d, %H:%M:%S")
        return {
            "created": date_time,
            "imageFilename": img_file_name,
            "imageWidth": img_width,
            "imageHeight": img_height,
            "textRegions": list(),
        }

    def _process_regions(
        self,
        image,
        inferencer: "InferencerInterface",
        pred_score_threshold_regions,
        pred_score_threshold_lines,
        containments_threshold,
        htr_threshold=0.7,
    ):
        """Predict regions in ``image`` and collect those with good transcriptions.

        A region is kept only when the mean HTR score over all of its lines is
        strictly greater than ``htr_threshold``. Regions without any detected
        line are skipped.

        Returns:
            A list of dicts with keys ``id``, ``boundary`` and ``textLines``.
        """
        _, regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered = inferencer.predict_regions(
            image,
            pred_score_threshold=pred_score_threshold_regions,
            containments_threshold=containments_threshold,
            visualize=False,
        )

        region_data_list = []
        regions = zip(regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered)
        for i, (text_region, reg_pol, mask) in tqdm(enumerate(regions)):
            region_id = "region_" + str(i)
            text_lines, htr_scores = self._process_lines(
                text_region,
                inferencer,
                pred_score_threshold_lines,
                containments_threshold,
                mask,
                region_id,
            )
            # An empty (non-None) line list gives no scores; without this guard
            # the mean below would raise ZeroDivisionError.
            if text_lines is None or not htr_scores:
                continue

            mean_htr_score = sum(htr_scores) / len(htr_scores)
            if mean_htr_score > htr_threshold:
                region_data_list.append(
                    {"id": region_id, "boundary": reg_pol, "textLines": text_lines}
                )
        return region_data_list

    def _process_lines(
        self,
        text_region,
        inferencer: "InferencerInterface",
        pred_score_threshold_lines,
        containments_threshold,
        mask,
        region_id,
        htr_threshold=0.7,
    ):
        """Predict and transcribe the text lines of one cropped region.

        Returns:
            ``(text_lines, htr_scores)``. ``htr_scores`` holds the score of every
            transcribed line; ``text_lines`` holds dicts (``id``, ``boundary``,
            ``unicode``, ``pred_score``) only for lines whose score is strictly
            greater than ``htr_threshold``. Returns ``(None, None)`` when the
            inferencer reports no lines (``None``).
        """
        _, lines_cropped_ordered, line_polygons_ordered = inferencer.predict_lines(
            text_region,
            pred_score_threshold=pred_score_threshold_lines,
            containments_threshold=containments_threshold,
            visualize=False,
            custom_track=False,
        )
        if lines_cropped_ordered is None:
            return None, None

        line_polygons_ordered_trans = self.process_seg_mask._translate_line_coords(mask, line_polygons_ordered)

        htr_scores = []
        text_lines = []
        for j, (line, line_pol) in enumerate(zip(lines_cropped_ordered, line_polygons_ordered_trans)):
            transcribed_text, htr_score = inferencer.transcribe(line)
            line_data = {
                "id": "line_" + region_id + "_" + str(j),
                "boundary": line_pol,
                # Escaped here because the Jinja environment does not autoescape.
                "unicode": self._escape_xml_chars(transcribed_text),
                "pred_score": round(htr_score, 4),
            }
            htr_scores.append(htr_score)
            if htr_score > htr_threshold:
                text_lines.append(line_data)
        return text_lines, htr_scores

    def _escape_xml_chars(self, textline):
        """Escape the five XML special characters (& < > ' ") in ``textline``.

        ``&`` must be replaced first so the entities introduced by the later
        replacements are not escaped a second time.
        """
        return (
            textline.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace("'", "&apos;")
            .replace('"', "&quot;")
        )
if __name__ == "__main__":
    # This module only provides XMLHelper; executing it directly does nothing.
    ...