deprem-ocr / ocr /postprocess /east_postprocess.py
Goodsea's picture
paddleocr
fc8c192
raw history blame
No virus
4.34 kB
from __future__ import absolute_import, division, print_function
import cv2
import numpy as np
import paddle
from .locality_aware_nms import nms_locality
class EASTPostProcess(object):
    """
    The post process for EAST.

    Turns the EAST head outputs (a text score map and an 8-channel quadrangle
    geometry map per image) into integer text-box quadrangles expressed in the
    coordinates of the original input image.
    """

    # The score/geo maps are indexed at 1/4 of the input resolution: pixel
    # coordinates are multiplied by this when restoring quads and box corners
    # are divided by it when rasterising them back onto the score map.
    _FEATURE_STRIDE = 4

    # Set after the first failed attempt to import ``lanms`` so that neither
    # the import nor the install hint is repeated for every image.
    _lanms_unavailable = False

    def __init__(self, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2, **kwargs):
        """
        Args:
            score_thresh: pixels whose score is strictly above this value
                seed quadrangle proposals.
            cover_thresh: after NMS, boxes whose mean score inside the
                polygon is not strictly above this value are dropped.
            nms_thresh: threshold forwarded to the NMS routine
                (``lanms.merge_quadrangle_n9`` or ``nms_locality``).
            **kwargs: accepted and ignored.
        """
        self.score_thresh = score_thresh
        self.cover_thresh = cover_thresh
        self.nms_thresh = nms_thresh

    def restore_rectangle_quad(self, origin, geometry):
        """
        Restore quadrangles from per-pixel corner offsets.

        Args:
            origin: (n, 2) array of (x, y) pixel positions.
            geometry: (n, 8) array of offsets; corner k is
                ``origin - geometry[:, 2k:2k+2]``.

        Returns:
            (n, 4, 2) array of quadrangle corners.
        """
        # Repeat (x, y) once per corner so it lines up with the 8 offsets.
        origin_concat = np.concatenate((origin, origin, origin, origin), axis=1)  # (n, 8)
        pred_quads = origin_concat - geometry
        return pred_quads.reshape((-1, 4, 2))  # (n, 4, 2)

    @classmethod
    def _merge_boxes(cls, boxes, nms_thresh):
        """Run NMS on (n, 9) boxes, preferring the ``lanms`` backend if importable."""
        if not cls._lanms_unavailable:
            try:
                import lanms
            except Exception:
                # Only the import is guarded. It is deliberately broader than
                # ImportError because some lanms builds may fail at import
                # with other errors. NOTE(review): confirm for lanms-nova.
                # Errors raised by the merge itself are no longer swallowed.
                cls._lanms_unavailable = True
                print(
                    "you should install lanms by pip3 install lanms-nova to speed up nms_locality"
                )
            else:
                return lanms.merge_quadrangle_n9(boxes, nms_thresh)
        return nms_locality(boxes.astype(np.float64), nms_thresh)

    def detect(
        self, score_map, geo_map, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2
    ):
        """
        Restore text boxes from a score map and a geo map of one image.

        Args:
            score_map: array whose first channel (``score_map[0]``) is the
                2-D text score map.
            geo_map: 3-D geometry array with the 8 offset channels on axis 0.
            score_thresh: seed threshold on the score map (strict ``>``).
            cover_thresh: minimum mean in-polygon score kept after NMS (strict ``>``).
            nms_thresh: threshold forwarded to the NMS routine.

        Returns:
            An empty list when nothing survives. Otherwise an (m, 9) array of
            8 corner coordinates followed by the mean in-polygon score.
        """
        stride = self._FEATURE_STRIDE
        score_map = score_map[0]
        # (C, H, W) -> (H, W, C); identical to swapaxes(1, 0) then swapaxes(1, 2).
        geo_map = np.transpose(geo_map, (1, 2, 0))

        # Filter the score map; rows of xy_text are (y, x).
        xy_text = np.argwhere(score_map > score_thresh)
        if len(xy_text) == 0:
            return []
        # Sort the candidates along the y axis. A stable sort keeps argwhere's
        # row-major (x) order within each row, so the box order fed to the
        # locality-aware merge is deterministic.
        xy_text = xy_text[np.argsort(xy_text[:, 0], kind="stable")]

        # Restore quad proposals in input-image scale ((x, y) * stride).
        text_box_restored = self.restore_rectangle_quad(
            xy_text[:, ::-1] * stride, geo_map[xy_text[:, 0], xy_text[:, 1], :]
        )
        boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
        boxes[:, :8] = text_box_restored.reshape((-1, 8))
        boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]

        boxes = self._merge_boxes(boxes, nms_thresh)
        if boxes.shape[0] == 0:
            return []

        # Here we filter some low score boxes by the average score map;
        # this is different from the original paper.
        mask = np.zeros_like(score_map, dtype=np.uint8)
        for i, box in enumerate(boxes):
            mask.fill(0)
            poly = box[:8].reshape((-1, 4, 2)).astype(np.int32) // stride
            cv2.fillPoly(mask, poly, 1)
            boxes[i, 8] = cv2.mean(score_map, mask)[0]
        boxes = boxes[boxes[:, 8] > cover_thresh]
        return boxes

    def sort_poly(self, p):
        """
        Put the four corners of a polygon into a canonical order.

        The corner with the smallest ``x + y`` (first one on ties) is moved to
        the front, keeping cyclic order. If the first edge has
        ``|dx| <= |dy|``, the order of the remaining three corners is reversed.

        Args:
            p: (4, 2) array of (x, y) corners.

        Returns:
            (4, 2) array with the same corners reordered.
        """
        start = np.argmin(np.sum(p, axis=1))
        p = p[[(start + k) % 4 for k in range(4)]]
        if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
            return p
        return p[[0, 3, 2, 1]]

    def _to_image_polys(self, boxes, shape):
        """
        Rescale detected boxes to original-image coordinates.

        The boxes are rescaled, cast to int32 and canonically ordered. Quads
        with a side shorter than 5 pixels are dropped.
        """
        boxes_norm = []
        if len(boxes) > 0:
            _, _, ratio_h, ratio_w = shape
            quads = boxes[:, :8].reshape((-1, 4, 2))
            quads[:, :, 0] /= ratio_w
            quads[:, :, 1] /= ratio_h
            for quad in quads:
                quad = self.sort_poly(quad.astype(np.int32))
                # Drop degenerate quads whose first or last edge is < 5 px.
                if (
                    np.linalg.norm(quad[0] - quad[1]) < 5
                    or np.linalg.norm(quad[3] - quad[0]) < 5
                ):
                    continue
                boxes_norm.append(quad)
        return np.array(boxes_norm)

    def __call__(self, outs_dict, shape_list):
        """
        Post-process a batch of EAST predictions.

        Args:
            outs_dict: mapping with batched "f_score" and "f_geo" predictions,
                either ``paddle.Tensor`` or numpy arrays.
            shape_list: one entry per image, unpacked as
                ``(src_h, src_w, ratio_h, ratio_w)``; only the ratios are used.

        Returns:
            A list with one ``{"points": np.ndarray}`` dict per image.
        """
        score_list = outs_dict["f_score"]
        geo_list = outs_dict["f_geo"]
        # Convert each output independently so mixed inputs are handled.
        if isinstance(score_list, paddle.Tensor):
            score_list = score_list.numpy()
        if isinstance(geo_list, paddle.Tensor):
            geo_list = geo_list.numpy()

        dt_boxes_list = []
        for ino in range(len(shape_list)):
            boxes = self.detect(
                score_map=score_list[ino],
                geo_map=geo_list[ino],
                score_thresh=self.score_thresh,
                cover_thresh=self.cover_thresh,
                nms_thresh=self.nms_thresh,
            )
            dt_boxes_list.append(
                {"points": self._to_image_polys(boxes, shape_list[ino])}
            )
        return dt_boxes_list