from __future__ import absolute_import, division, print_function import cv2 import numpy as np import paddle from .locality_aware_nms import nms_locality class EASTPostProcess(object): """ The post process for EAST. """ def __init__(self, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2, **kwargs): self.score_thresh = score_thresh self.cover_thresh = cover_thresh self.nms_thresh = nms_thresh def restore_rectangle_quad(self, origin, geometry): """ Restore rectangle from quadrangle. """ # quad origin_concat = np.concatenate( (origin, origin, origin, origin), axis=1 ) # (n, 8) pred_quads = origin_concat - geometry pred_quads = pred_quads.reshape((-1, 4, 2)) # (n, 4, 2) return pred_quads def detect( self, score_map, geo_map, score_thresh=0.8, cover_thresh=0.1, nms_thresh=0.2 ): """ restore text boxes from score map and geo map """ score_map = score_map[0] geo_map = np.swapaxes(geo_map, 1, 0) geo_map = np.swapaxes(geo_map, 1, 2) # filter the score map xy_text = np.argwhere(score_map > score_thresh) if len(xy_text) == 0: return [] # sort the text boxes via the y axis xy_text = xy_text[np.argsort(xy_text[:, 0])] # restore quad proposals text_box_restored = self.restore_rectangle_quad( xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :] ) boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32) boxes[:, :8] = text_box_restored.reshape((-1, 8)) boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]] try: import lanms boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh) except: print( "you should install lanms by pip3 install lanms-nova to speed up nms_locality" ) boxes = nms_locality(boxes.astype(np.float64), nms_thresh) if boxes.shape[0] == 0: return [] # Here we filter some low score boxes by the average score map, # this is different from the orginal paper. for i, box in enumerate(boxes): mask = np.zeros_like(score_map, dtype=np.uint8) cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1) boxes[i, 8] = cv2.mean(score_map, mask)[0] boxes = boxes[boxes[:, 8] > cover_thresh] return boxes def sort_poly(self, p): """ Sort polygons. """ min_axis = np.argmin(np.sum(p, axis=1)) p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]] if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]): return p else: return p[[0, 3, 2, 1]] def __call__(self, outs_dict, shape_list): score_list = outs_dict["f_score"] geo_list = outs_dict["f_geo"] if isinstance(score_list, paddle.Tensor): score_list = score_list.numpy() geo_list = geo_list.numpy() img_num = len(shape_list) dt_boxes_list = [] for ino in range(img_num): score = score_list[ino] geo = geo_list[ino] boxes = self.detect( score_map=score, geo_map=geo, score_thresh=self.score_thresh, cover_thresh=self.cover_thresh, nms_thresh=self.nms_thresh, ) boxes_norm = [] if len(boxes) > 0: h, w = score.shape[1:] src_h, src_w, ratio_h, ratio_w = shape_list[ino] boxes = boxes[:, :8].reshape((-1, 4, 2)) boxes[:, :, 0] /= ratio_w boxes[:, :, 1] /= ratio_h for i_box, box in enumerate(boxes): box = self.sort_poly(box.astype(np.int32)) if ( np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5 ): continue boxes_norm.append(box) dt_boxes_list.append({"points": np.array(boxes_norm)}) return dt_boxes_list