File size: 2,553 Bytes
5ebeb73
 
 
 
 
 
 
20103df
a36a5bd
5ebeb73
7263d32
 
 
5ebeb73
 
 
 
 
 
7263d32
5ebeb73
 
 
 
 
c60ebd1
5ebeb73
 
 
 
 
 
 
7263d32
c60ebd1
 
 
 
 
 
5ebeb73
 
 
 
 
 
7263d32
5ebeb73
089249c
 
5ebeb73
 
 
7263d32
 
5ebeb73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from typing import Protocol

import mmcv
import numpy as np

from src.htr_pipeline.inferencer import Inferencer
from src.htr_pipeline.utils.helper import timer_func
from src.htr_pipeline.utils.parser_xml import XmlParser
from src.htr_pipeline.utils.pipeline_inferencer import PipelineInferencer
from src.htr_pipeline.utils.preprocess_img import Preprocess
from src.htr_pipeline.utils.process_segmask import SegMaskHelper
from src.htr_pipeline.utils.visualize_xml import XmlViz
from src.htr_pipeline.utils.xml_helper import XMLHelper


class Pipeline:
    def __init__(self, inferencer: Inferencer) -> None:
        self.inferencer = inferencer
        self.preprocess_img = Preprocess()
        self.pipeline_inferencer = PipelineInferencer(SegMaskHelper(), XMLHelper())

    @timer_func
    def running_htr_pipeline(
        self,
        input_image: np.ndarray,
        htr_tool_transcriber_model_dropdown,
        pred_score_threshold_regions: float = 0.4,
        pred_score_threshold_lines: float = 0.4,
        containments_threshold: float = 0.5,
    ) -> str:
        input_image = self.preprocess_img.binarize_img(input_image)
        image = mmcv.imread(input_image)

        rendered_xml = self.pipeline_inferencer.image_to_page_xml(
            image,
            htr_tool_transcriber_model_dropdown,
            pred_score_threshold_regions,
            pred_score_threshold_lines,
            containments_threshold,
            self.inferencer,
        )

        return rendered_xml

    @timer_func
    def visualize_xml(self, input_image: np.ndarray) -> np.ndarray:
        xml_viz = XmlViz()
        bin_input_image = self.preprocess_img.binarize_img(input_image)
        xml_image, text_polygon_dict = xml_viz.visualize_xml(bin_input_image)
        return xml_image, text_polygon_dict

    @timer_func
    def parse_xml_to_txt(self) -> None:
        xml_visualizer_and_parser = XmlParser()
        xml_visualizer_and_parser.xml_to_txt()


class PipelineInterface(Protocol):
    def __init__(self, inferencer: Inferencer) -> None:
        ...

    def running_htr_pipeline(
        self,
        input_image: np.ndarray,
        pred_score_threshold_regions: float = 0.4,
        pred_score_threshold_lines: float = 0.4,
        containments_threshold: float = 0.5,
    ) -> str:
        ...

    def visualize_xml(self, input_image: np.ndarray) -> np.ndarray:
        ...

    def parse_xml_to_txt(self) -> None:
        ...


if __name__ == "__main__":
    prediction_model = Inferencer()
    pipeline = Pipeline(prediction_model)