File size: 9,677 Bytes
c9019cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
#!/usr/bin/env python

# Copyright (c) 2022, National Diet Library, Japan
#
# This software is released under the CC BY 4.0.
# https://creativecommons.org/licenses/by/4.0/

import sys
import os
from pathlib import Path
from .utils import auto_run
from typing import List
import xml.etree.ElementTree as ET

import mmcv
from mmdet.apis import (inference_detector, init_detector)


def generate_class_colors(class_num):
    import cv2
    import numpy as np
    colors = 255 * np.ones((class_num, 3), dtype=np.uint8)
    colors[:, 0] = np.linspace(0, 179, class_num)
    colors = cv2.cvtColor(colors[None, ...], cv2.COLOR_HSV2BGR)[0]
    return colors


def draw_legand(img, origin, classes, colors, ssz: int = 16):
    import cv2
    c_num = len(classes)
    x, y = origin[0], origin[1]
    for c in range(c_num):
        color = colors[c]
        color = (int(color[0]), int(color[1]), int(color[2]))
        text = classes[c]
        img = cv2.rectangle(img, (x, y), (x + ssz - 1, y + ssz - 1), color, -1)
        img = cv2.putText(img, text, (x + ssz, y + ssz), cv2.FONT_HERSHEY_PLAIN,
                          1, (255, 0, 0), 1, cv2.LINE_AA)
        y += ssz
    return img


class LayoutDetector:
    def __init__(self, config: str, checkpoint: str, device: str):
        print(f'load from config={config}, checkpoint={checkpoint}')
        self.load(config, checkpoint, device)
        cfg = mmcv.Config.fromfile(config)
        self.classes = cfg.classes
        self.colors = generate_class_colors(len(self.classes))

    def load(self, config: str, checkpoint: str, device: str):
        self.model = init_detector(config, checkpoint, device)

    def predict(self, img_path: str):
        return inference_detector(self.model, img_path)

    def show(self, img_path: str, result, score_thr: float = 0.3, border: int = 3, show_legand: bool = True):
        import cv2
        img = cv2.imread(img_path)
        for c in range(len(result)):
            color = self.colors[c]
            color = (int(color[0]), int(color[1]), int(color[2]))
            for pred in result[c]:
                if float(pred[4]) < score_thr:
                    continue
                x0, y0 = int(pred[0]), int(pred[1])
                x1, y1 = int(pred[2]), int(pred[3])
                img = cv2.rectangle(img, (x0, y0), (x1, y1), color, border)

        sz = max(img.shape[0], img.shape[1])
        scale = 1024.0 / sz
        img = cv2.resize(img, dsize=None, fx=scale, fy=scale)

        if show_legand:
            ssz = 16
            c_num = len(self.classes)
            org_width = img.shape[1]
            img = cv2.copyMakeBorder(
                img, 0, 0, 0, 8 * c_num, cv2.BORDER_REPLICATE)
            x = org_width
            y = img.shape[0] - ssz * c_num
            img = draw_legand(img, (x, y),
                              self.classes, self.colors, ssz=ssz)

        return img

    def draw_rects_with_data(self, img, result, score_thr: float = 0.3, border: int = 3, show_legand: bool = True):
        import cv2
        for c in range(len(result)):
            color = self.colors[c]
            color = (int(color[0]), int(color[1]), int(color[2]))
            for pred in result[c]:
                if float(pred[4]) < score_thr:
                    continue
                x0, y0 = int(pred[0]), int(pred[1])
                x1, y1 = int(pred[2]), int(pred[3])
                img = cv2.rectangle(img, (x0, y0), (x1, y1), color, border)

        sz = max(img.shape[0], img.shape[1])
        scale = 1024.0 / sz
        img = cv2.resize(img, dsize=None, fx=scale, fy=scale)

        if show_legand:
            ssz = 16
            c_num = len(self.classes)
            org_width = img.shape[1]
            img = cv2.copyMakeBorder(
                img, 0, 0, 0, 8 * c_num, cv2.BORDER_REPLICATE)
            x = org_width
            y = img.shape[0] - ssz * c_num
            img = draw_legand(img, (x, y), self.classes, self.colors, ssz=ssz)

        return img


def convert_to_xml_string(img_path, classes, result, score_thr: float = 0.3):
    import cv2
    img = cv2.imread(img_path)
    img_h, img_w = img.shape[0:2]

    from .ndl_parser import name_to_org_name

    img_name = os.path.basename(img_path)
    s = f'<PAGE IMAGENAME = "{img_name}" WIDTH = "{img_w}" HEIGHT = "{img_h}">\n'

    for c in range(len(classes)):
        cls = classes[c]
        if cls.startswith('line_'):
            for line in result[c]:
                conf = float(line[4])
                if conf < score_thr:
                    continue
                x, y = int(line[0]), int(line[1])
                w, h = int(line[2] - line[0]), int(line[3] - line[1])
                s += f'<LINE TYPE = "{name_to_org_name(cls)}" X = "{x}" Y = "{y}" WIDTH = "{w}" HEIGHT = "{h}" CONF = "{conf:0.3f}"></LINE>\n'
        elif cls.startswith('block_'):
            for block in result[c]:
                conf = float(block[4])
                if conf < score_thr:
                    continue
                x, y = int(block[0]), int(block[1])
                w, h = int(block[2] - block[0]), int(block[3] - block[1])
                s += f'<BLOCK TYPE = "{name_to_org_name(cls)}" X = "{x}" Y = "{y}" WIDTH = "{w}" HEIGHT = "{h}" CONF = "{conf:0.3f}"></BLOCK>\n'

    s += '</PAGE>\n'
    return s


def convert_to_xml_string_with_data(img, img_path, classes, result, score_thr: float = 0.3):
    img_h, img_w = img.shape[0:2]

    from .ndl_parser import name_to_org_name

    img_name = os.path.basename(img_path)
    s = f'<PAGE IMAGENAME = "{img_name}" WIDTH = "{img_w}" HEIGHT = "{img_h}">\n'

    for c in range(len(classes)):
        cls = classes[c]
        if cls.startswith('line_'):
            for line in result[c]:
                conf = float(line[4])
                if conf < score_thr:
                    continue
                x, y = int(line[0]), int(line[1])
                w, h = int(line[2] - line[0]), int(line[3] - line[1])
                s += f'<LINE TYPE = "{name_to_org_name(cls)}" X = "{x}" Y = "{y}" WIDTH = "{w}" HEIGHT = "{h}" CONF = "{conf:0.3f}"></LINE>\n'
        elif cls.startswith('block_'):
            for block in result[c]:
                conf = float(block[4])
                if conf < score_thr:
                    continue
                x, y = int(block[0]), int(block[1])
                w, h = int(block[2] - block[0]), int(block[3] - block[1])
                s += f'<BLOCK TYPE = "{name_to_org_name(cls)}" X = "{x}" Y = "{y}" WIDTH = "{w}" HEIGHT = "{h}" CONF = "{conf:0.3f}"></BLOCK>\n'

    s += '</PAGE>\n'
    return s


def run_layout_detection(img_paths: List[str] = None, list_path: str = None, output_path: str = "layout_prediction.xml",
                         config: str = './models/config_file.py',
                         checkpoint: str = './models/trained_weights.pth',
                         device: str = 'cuda:0', score_thr: float = 0.3, use_show: bool = False, dump_dir: str = None):
    detector = LayoutDetector(config, checkpoint, device)
    if list_path is not None:
        img_paths = list([s.strip() for s in open(list_path).readlines()])
    if img_paths is None:
        print('Please specify --img_paths or --list_path')
        return -1
    if dump_dir is not None:
        Path(dump_dir).mkdir(exist_ok=True)
    with open(output_path, 'w') as f:
        def tee(s):
            print(s, file=f, end="")
            print(s, file=sys.stdout, end="")

        tee('<?xml version="1.0" encoding="utf-8" standalone="yes"?><OCRDATASET xmlns="">\n')

        for img_path in img_paths:
            result = detector.predict(img_path)
            xml_str = convert_to_xml_string(
                img_path, detector.classes, result, score_thr=score_thr)
            tee(xml_str)

            if use_show:
                import cv2
                img = detector.show(img_path, result, score_thr=score_thr)
                cv2.namedWindow('show')
                cv2.imshow('show', img)
                if 27 == cv2.waitKey(0):
                    break

            if dump_dir is not None:
                import cv2
                img = detector.show(img_path, result, score_thr=score_thr)
                cv2.imwrite(str(Path(dump_dir) / Path(img_path).name), img)

        tee('</OCRDATASET>\n')


class InferencerWithCLI:
    def __init__(self, conf_dict):
        config = conf_dict['config_path']
        checkpoint = conf_dict['checkpoint_path']
        device = conf_dict['device']
        self.detector = LayoutDetector(config, checkpoint, device)

    def inference_wich_cli(self, img=None, img_path='',
                           score_thr: float = 0.3, dump: bool = False):
        ET.register_namespace('', 'NDLOCRDATASET')
        node = ET.fromstring(
            '<?xml version="1.0" encoding="utf-8" standalone="yes"?><OCRDATASET xmlns="">\n</OCRDATASET>\n')

        # prediction
        if self.detector is None:
            print('ERROR: Layout detector is not created.')
            return None
        result = self.detector.predict(img)

        # xml creation
        xml_str = convert_to_xml_string_with_data(
            img, img_path, self.detector.classes, result, score_thr=score_thr)

        # xml conversion from string
        result_xml = ET.fromstring(xml_str)
        node.append(result_xml)
        tree = ET.ElementTree(node)

        # xml conversion from string
        dump_img = None
        if dump is not None:
            dump_img = self.detector.draw_rects_with_data(
                img, result, score_thr=score_thr)

        return {'xml': tree, 'dump_img': dump_img}


if __name__ == '__main__':
    auto_run(run_layout_detection)