File size: 2,653 Bytes
d10a366
 
4f6e58b
d10a366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f6e58b
d10a366
 
 
 
 
 
 
 
 
 
4f6e58b
d10a366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f6e58b
 
 
 
 
 
 
 
d10a366
4f6e58b
d10a366
4f6e58b
 
d10a366
4f6e58b
 
 
 
d10a366
4f6e58b
 
 
 
d10a366
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os.path
from functools import lru_cache
from typing import List, Tuple

import cv2
import numpy as np
from huggingface_hub import HfApi, HfFileSystem, hf_hub_download
from imgutils.data import ImageTyping, load_image
from imgutils.utils import open_onnx_model

# Hugging Face Hub API client (not referenced in the visible portion of this module).
hf_client = HfApi()
# Filesystem-style view of the Hugging Face Hub; used to glob the model files
# of the ``deepghs/text_detection`` repository.
hf_fs = HfFileSystem()


@lru_cache()
def _get_available_models() -> Tuple[str, ...]:
    """
    List the text-detection models published in ``deepghs/text_detection``.

    A model is any top-level directory of that repository which contains an
    ``end2end.onnx`` file; the model name is the directory name.

    :return: Tuple of model names, in the order produced by ``hf_fs.glob``.
    """
    # Build the complete result before it is cached. A cached *generator* would be
    # a single shared iterator: every call after the first would get the same,
    # already-exhausted object and yield no models at all. A tuple is immutable,
    # so sharing it between callers through the cache is safe.
    #
    # Hub paths always use '/' regardless of os.sep, and the glob pattern fixes
    # the model directory as the parent of 'end2end.onnx'.
    return tuple(
        path.split('/')[-2]
        for path in hf_fs.glob('deepghs/text_detection/*/end2end.onnx')
    )


# Names of every model found in the Hub repository. Computed once, at import
# time, so importing this module queries the Hugging Face Hub.
_ALL_MODELS = list(_get_available_models())
# Model used by ``detect_text`` when the caller does not pass ``model``.
_DEFAULT_MODEL = 'dbnetpp_resnet50_fpnc_1200e_icdar2015'


@lru_cache()
def _get_onnx_session(model):
    """
    Return the ONNX inference session for a text-detection model.

    The ``end2end.onnx`` file of the model is fetched with ``hf_hub_download``
    and opened with ``open_onnx_model``. Thanks to ``lru_cache``, repeated calls
    with the same ``model`` reuse the session instead of reopening it.

    :param model: Model name, i.e. a directory of the ``deepghs/text_detection``
        repository.
    :return: The session object produced by ``open_onnx_model``.
    """
    onnx_file = hf_hub_download(
        'deepghs/text_detection',
        f'{model}/end2end.onnx',
    )
    return open_onnx_model(onnx_file)


def _get_heatmap_of_text(image: ImageTyping, model: str) -> np.ndarray:
    """
    Run a text-detection model on an image and return its output heatmap.

    Preprocessing steps:

    - load the image as RGB;
    - scale it to ``[0, 1]``;
    - zero-pad the bottom/right so both sides are multiples of 32;
    - normalise each channel with a fixed mean/std.

    The padded area is cropped off the model output again, so the heatmap
    lines up with the original pixels.

    :param image: Anything accepted by ``imgutils.data.load_image``.
    :param model: Model name, forwarded to ``_get_onnx_session``.
    :return: The first batch item of the model's ``'output'`` tensor, cropped to
        the original ``[:height, :width]``.
    """
    # ImageTyping is wider than "RGB PIL image". Without this step:
    # - paths/bytes fail on ``.size``;
    # - grayscale images fail the (2, 0, 1) transpose;
    # - RGBA images fail to broadcast against the 3-channel mean/std.
    image = load_image(image, mode='RGB')
    origin_width, origin_height = width, height = image.size

    # Round both sides up to a multiple of 32. Presumably required by the
    # network's downsampling stages -- TODO confirm against the model export.
    align = 32
    if width % align != 0:
        width += align - width % align
    if height % align != 0:
        height += align - height % align

    # HWC uint8 -> CHW float32 in [0, 1].
    input_ = np.array(image).transpose((2, 0, 1)).astype(np.float32) / 255.0
    # Add the batch axis, then zero-pad only the bottom and right edges.
    # noinspection PyTypeChecker
    input_ = np.pad(input_[None, ...], ((0, 0), (0, 0), (0, height - origin_height), (0, width - origin_width)))

    # Per-channel normalisation; these mean/std values presumably match the
    # model's training preprocessing -- NOTE(review): confirm.
    def _normalize(data, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)):
        mean, std = np.asarray(mean), np.asarray(std)
        return (data - mean[None, :, None, None]) / std[None, :, None, None]

    ort = _get_onnx_session(model)

    input_ = _normalize(input_).astype(np.float32)
    output_, = ort.run(['output'], {'input': input_})
    heatmap = output_[0]
    # Drop the padded rows/columns so the heatmap covers exactly the input image.
    heatmap = heatmap[:origin_height, :origin_width]

    return heatmap


def _get_bounding_box_of_text(image: ImageTyping, model: str, threshold: float) \
        -> List[Tuple[Tuple[int, int, int, int], float]]:
    """
    Extract bounding boxes of text regions from the model's heatmap.

    The heatmap is scaled to 8-bit. Every outer contour of its non-zero pixels
    yields a candidate box. A box is kept when the mean heatmap value inside
    it is at least ``threshold``.

    :param image: Image to analyse, forwarded to ``_get_heatmap_of_text``.
    :param model: Model name, forwarded to ``_get_heatmap_of_text``.
    :param threshold: Minimum mean heatmap score for a box to be kept.
    :return: List of ``((x0, y0, x1, y1), score)``; ``x1``/``y1`` are exclusive.
    """
    heatmap = _get_heatmap_of_text(image, model)

    # Clip before casting. A float -> uint8 cast does not saturate (out-of-range
    # values are undefined and wrap in practice): a value slightly above 1.0
    # could become 0, and a negative one could become a large positive pixel.
    # For heatmaps already in [0, 1] this is a no-op.
    mask = np.clip(heatmap * 255.0, 0, 255).astype(np.uint8)

    # OpenCV 3 returns (image, contours, hierarchy); OpenCV 4 returns
    # (contours, hierarchy).
    c_rets = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours = c_rets[0] if len(c_rets) == 2 else c_rets[1]

    bboxes = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        x0, y0, x1, y1 = x, y, x + w, y + h
        # Score over the whole bounding box, not just the pixels of the contour.
        score = heatmap[y0:y1, x0:x1].mean().item()
        if score >= threshold:
            bboxes.append(((x0, y0, x1, y1), score))

    return bboxes


def detect_text(image: ImageTyping, model: str = _DEFAULT_MODEL, threshold: float = 0.05):
    """
    Detect text regions in an image.

    :param image: Image to analyse.
    :param model: Name of the detection model (a directory of the
        ``deepghs/text_detection`` repository); defaults to ``_DEFAULT_MODEL``.
    :param threshold: Minimum mean heatmap score for a region to be reported.
    :return: List of ``((x0, y0, x1, y1), 'text', score)`` tuples, one per
        detected region, with ``x1``/``y1`` exclusive.
    """
    return [
        ((x0, y0, x1, y1), 'text', score)
        for (x0, y0, x1, y1), score in _get_bounding_box_of_text(image, model, threshold)
    ]