| | |
import math
import os
import shutil
import urllib
import urllib.request
import warnings

import cv2
import mmcv
import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw, ImageFont

import mmocr.utils as utils
| |
|
| |
|
def overlay_mask_img(img, mask):
    """Draw mask boundaries on image for visualization.

    Args:
        img (ndarray): The input image.
        mask (ndarray): The instance mask.

    Returns:
        img (ndarray): The output image with instance boundaries on it.
    """
    assert isinstance(img, np.ndarray)
    assert isinstance(mask, np.ndarray)

    # Trace only the outer contours of the mask and stroke them in green.
    found_contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(img, found_contours, -1, (0, 255, 0), 1)

    return img
| |
|
| |
|
def show_feature(features, names, to_uint8, out_file=None):
    """Visualize a list of feature maps.

    Args:
        features (list(ndarray)): The feature map list.
        names (list(str)): The visualized title list.
        to_uint8 (list(1|0)): The list indicating whether to convert
            feature maps to uint8.
        out_file (str): The output file name. If set to None,
            the output image will be shown without saving.
    """
    assert utils.is_type_list(features, np.ndarray)
    assert utils.is_type_list(names, str)
    assert utils.is_type_list(to_uint8, int)
    assert utils.is_none_or_type(out_file, str)
    assert utils.equal_len(features, names, to_uint8)

    # Arrange all maps on a near-square subplot grid.
    grid = math.ceil(math.sqrt(len(features)))

    for idx, (feat, title) in enumerate(zip(features, names)):
        plt.subplot(grid, grid, idx + 1)
        plt.title(title)
        plt.imshow(feat.astype(np.uint8) if to_uint8[idx] else feat)
    if out_file is None:
        plt.show()
    else:
        plt.savefig(out_file)
| |
|
| |
|
def show_img_boundary(img, boundary):
    """Show image and instance boundaries.

    Args:
        img (ndarray): The input image.
        boundary (list[float or int]): The input boundary.
    """
    assert isinstance(img, np.ndarray)
    assert utils.is_type_list(boundary, (int, float))

    # Flat [x1, y1, x2, y2, ...] -> (N, 1, 2) int points for cv2.polylines.
    pts = np.array(boundary).astype(np.int32).reshape(-1, 1, 2)
    cv2.polylines(img, [pts], True, color=(0, 255, 0), thickness=1)

    plt.imshow(img)
    plt.show()
| |
|
| |
|
def show_pred_gt(preds,
                 gts,
                 show=False,
                 win_name='',
                 wait_time=0,
                 out_file=None):
    """Show detection and ground truth for one image.

    Args:
        preds (list[list[float]]): The detection boundary list.
        gts (list[list[float]]): The ground truth boundary list.
        show (bool): Whether to show the image.
        win_name (str): The window name.
        wait_time (int): The value of waitKey param.
        out_file (str): The filename of the output.

    Returns:
        np.ndarray: The canvas with predictions (red) and ground truth
        (blue) drawn on it.
    """
    assert utils.is_2dlist(preds)
    assert utils.is_2dlist(gts)
    assert isinstance(show, bool)
    assert isinstance(win_name, str)
    assert isinstance(wait_time, int)
    assert utils.is_none_or_type(out_file, str)

    p_xy = [p for boundary in preds for p in boundary]
    gt_xy = [g for gt in gts for g in gt]

    max_xy = np.max(np.array(p_xy + gt_xy).reshape(-1, 2), axis=0)

    # Leave a 100px margin beyond the furthest point.
    width = int(max_xy[0]) + 100
    height = int(max_xy[1]) + 100

    # FIX: use uint8 instead of int8 — with int8, 255 overflowed to -1,
    # so the canvas was not actually white and image I/O misrendered it.
    img = np.ones((height, width, 3), np.uint8) * 255
    pred_color = mmcv.color_val('red')
    gt_color = mmcv.color_val('blue')
    thickness = 1

    for boundary in preds:
        cv2.polylines(
            img, [np.array(boundary).astype(np.int32).reshape(-1, 1, 2)],
            True,
            color=pred_color,
            thickness=thickness)
    for gt in gts:
        cv2.polylines(
            img, [np.array(gt).astype(np.int32).reshape(-1, 1, 2)],
            True,
            color=gt_color,
            thickness=thickness)
    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    return img
| |
|
| |
|
def imshow_pred_boundary(img,
                         boundaries_with_scores,
                         labels,
                         score_thr=0,
                         boundary_color='blue',
                         text_color='blue',
                         thickness=1,
                         font_scale=0.5,
                         show=True,
                         win_name='',
                         wait_time=0,
                         out_file=None,
                         show_score=False):
    """Draw boundaries and class labels (with scores) on an image.

    Args:
        img (str or ndarray): The image to be displayed.
        boundaries_with_scores (list[list[float]]): Boundaries with scores.
        labels (list[int]): Labels of boundaries.
        score_thr (float): Minimum score of boundaries to be shown.
        boundary_color (str or tuple or :obj:`Color`): Color of boundaries.
        text_color (str or tuple or :obj:`Color`): Color of texts.
        thickness (int): Thickness of lines.
        font_scale (float): Font scales of texts.
        show (bool): Whether to show the image.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
        out_file (str or None): The filename of the output.
        show_score (bool): Whether to show text instance score.

    Returns:
        ndarray or None: The visualized image, or None if there is no
        boundary to draw.
    """
    assert isinstance(img, (str, np.ndarray))
    assert utils.is_2dlist(boundaries_with_scores)
    assert utils.is_type_list(labels, int)
    assert utils.equal_len(boundaries_with_scores, labels)
    if len(boundaries_with_scores) == 0:
        # FIX: f-string instead of '+' so this no longer raises TypeError
        # when out_file is None.
        warnings.warn(f'0 text found in {out_file}')
        return None

    utils.valid_boundary(boundaries_with_scores[0])
    img = mmcv.imread(img)

    # Each boundary's last element is its score; keep only those above thr.
    scores = np.array([b[-1] for b in boundaries_with_scores])
    inds = scores > score_thr
    boundaries = [boundaries_with_scores[i][:-1] for i in np.where(inds)[0]]
    scores = [scores[i] for i in np.where(inds)[0]]
    labels = [labels[i] for i in np.where(inds)[0]]

    boundary_color = mmcv.color_val(boundary_color)
    text_color = mmcv.color_val(text_color)
    # FIX: the original re-assigned ``font_scale = 0.5`` here, silently
    # ignoring the caller's argument; the parameter is now honored.

    for boundary, score in zip(boundaries, scores):
        boundary_int = np.array(boundary).astype(np.int32)

        cv2.polylines(
            img, [boundary_int.reshape(-1, 1, 2)],
            True,
            color=boundary_color,
            thickness=thickness)

        if show_score:
            label_text = f'{score:.02f}'
            # Place the score just above the boundary's first point.
            cv2.putText(img, label_text,
                        (boundary_int[0], boundary_int[1] - 2),
                        cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    return img
| |
|
| |
|
def imshow_text_char_boundary(img,
                              text_quads,
                              boundaries,
                              char_quads,
                              chars,
                              show=False,
                              thickness=1,
                              font_scale=0.5,
                              win_name='',
                              wait_time=-1,
                              out_file=None):
    """Draw text boxes and char boxes on img.

    Args:
        img (str or ndarray): The img to be displayed.
        text_quads (list[list[int|float]]): The text boxes.
        boundaries (list[list[int|float]]): The boundary list.
        char_quads (list[list[list[int|float]]]): A 2d list of char boxes.
            char_quads[i] is for the ith text, and char_quads[i][j] is the jth
            char of the ith text.
        chars (list[list[char]]). The string for each text box.
        thickness (int): Thickness of lines.
        font_scale (float): Font scales of texts.
        show (bool): Whether to show the image.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
        out_file (str or None): The filename of the output.

    Returns:
        ndarray: The image with text boxes, boundaries, char boxes and the
        recognized strings drawn on it.
    """
    assert isinstance(img, (np.ndarray, str))
    assert utils.is_2dlist(text_quads)
    assert utils.is_2dlist(boundaries)
    assert utils.is_3dlist(char_quads)
    assert utils.is_2dlist(chars)
    assert utils.equal_len(text_quads, char_quads, boundaries)

    img = mmcv.imread(img)
    # Char boxes alternate between blue and green per text instance so that
    # adjacent instances are visually distinguishable.
    char_color = [mmcv.color_val('blue'), mmcv.color_val('green')]
    text_color = mmcv.color_val('red')
    text_inx = 0
    for text_box, boundary, char_box, txt in zip(text_quads, boundaries,
                                                 char_quads, chars):
        text_box = np.array(text_box)
        boundary = np.array(boundary)

        # Draw the text quad in red.
        text_box = text_box.reshape(-1, 2).astype(np.int32)
        cv2.polylines(
            img, [text_box.reshape(-1, 1, 2)],
            True,
            color=text_color,
            thickness=thickness)
        # Draw the (possibly empty) polygon boundary in the same color.
        # NOTE(review): unlike the other polylines calls in this file,
        # ``boundary`` is not cast to int32 here — cv2 requires integer
        # points, so this presumably relies on callers passing ints; verify.
        if boundary.shape[0] > 0:
            cv2.polylines(
                img, [boundary.reshape(-1, 1, 2)],
                True,
                color=text_color,
                thickness=thickness)

        # Draw each char box, alternating color per text instance.
        for b in char_box:
            b = np.array(b)
            c = char_color[text_inx % 2]
            b = b.astype(np.int32)
            cv2.polylines(
                img, [b.reshape(-1, 1, 2)], True, color=c, thickness=thickness)

        # Render the recognized string just above the quad's first corner.
        label_text = ''.join(txt)
        cv2.putText(img, label_text, (text_box[0, 0], text_box[0, 1] - 2),
                    cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
        text_inx = text_inx + 1

    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    return img
| |
|
| |
|
def tile_image(images):
    """Combined multiple images to one vertically.

    Args:
        images (list[np.ndarray]): Images to be combined.

    Returns:
        np.ndarray: The stacked image; narrower images are zero-padded on
        the right up to the widest image's width.
    """
    assert isinstance(images, list)
    assert len(images) > 0

    # FIX: promote grayscale images to 3-channel in a NEW list — the
    # original rewrote entries of the caller's list in place.
    images = [
        cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        if len(image.shape) == 2 else image for image in images
    ]

    widths = [img.shape[1] for img in images]
    heights = [img.shape[0] for img in images]
    h, w = sum(heights), max(widths)
    vis_img = np.zeros((h, w, 3), dtype=np.uint8)

    # Paste each image top-to-bottom, left-aligned.
    offset_y = 0
    for image in images:
        img_h, img_w = image.shape[:2]
        vis_img[offset_y:(offset_y + img_h), 0:img_w, :] = image
        offset_y += img_h

    return vis_img
| |
|
| |
|
def imshow_text_label(img,
                      pred_label,
                      gt_label,
                      show=False,
                      win_name='',
                      wait_time=-1,
                      out_file=None):
    """Draw predicted texts and ground truth texts on images.

    Args:
        img (str or np.ndarray): Image filename or loaded image.
        pred_label (str): Predicted texts.
        gt_label (str): Ground truth texts.
        show (bool): Whether to show the image.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
        out_file (str): The filename of the output.
    """
    assert isinstance(img, (np.ndarray, str))
    assert isinstance(pred_label, str)
    assert isinstance(gt_label, str)
    assert isinstance(show, bool)
    assert isinstance(win_name, str)
    assert isinstance(wait_time, int)

    img = mmcv.imread(img)

    # Normalize the input to a fixed height of 64 px, keeping aspect ratio.
    src_h, src_w = img.shape[:2]
    resize_height = 64
    resize_width = int(1.0 * src_w / src_h * resize_height)
    img = cv2.resize(img, (resize_width, resize_height))
    h, w = img.shape[:2]

    def render(label, bgr):
        # Chinese glyphs need a TTF font, so fall back to PIL rendering;
        # plain text is drawn with cv2.putText on a white canvas.
        if is_contain_chinese(label):
            return draw_texts_by_pil(img, [label], None)
        canvas = np.full((h, w, 3), 255, dtype=np.uint8)
        cv2.putText(canvas, label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.9,
                    bgr, 2)
        return canvas

    # Prediction (red) above the resized input image.
    images = [render(pred_label, (0, 0, 255)), img]

    # Ground truth (blue) goes underneath when provided.
    if gt_label != '':
        images.append(render(gt_label, (255, 0, 0)))

    img = tile_image(images)

    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    return img
| |
|
| |
|
def imshow_node(img,
                result,
                boxes,
                idx_to_cls=None,
                show=False,
                win_name='',
                wait_time=-1,
                out_file=None):
    """Draw node (text box) classification results of a KIE model.

    Args:
        img (str or np.ndarray): Image filename or loaded image.
        result (dict): Model output; ``result['nodes']`` is a Tensor of
            shape (num_boxes, num_classes) holding per-node class scores.
        boxes (list): Text boxes; ``boxes[i]`` supplies
            [x_min, y_min, x_max, y_max, ...] for the ith node.
        idx_to_cls (dict, optional): Mapping from predicted label index
            (as str) to a readable class name. Defaults to no mapping.
        show (bool): Whether to show the image.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
        out_file (str or None): The filename of the output.

    Returns:
        np.ndarray: Image with boxes on the left and the per-box predicted
        label(score) texts on the right.
    """
    # FIX: ``idx_to_cls={}`` was a mutable default argument shared across
    # calls; use None as the sentinel instead (backward-compatible).
    if idx_to_cls is None:
        idx_to_cls = {}

    img = mmcv.imread(img)
    h, w = img.shape[:2]

    # Per-node argmax class index and its score.
    max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1)
    node_pred_label = max_idx.numpy().tolist()
    node_pred_score = max_value.numpy().tolist()

    texts, text_boxes = [], []
    for i, box in enumerate(boxes):
        new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
                   [box[0], box[3]]]
        Pts = np.array([new_box], np.int32)
        cv2.polylines(
            img, [Pts.reshape((-1, 1, 2))],
            True,
            color=(255, 255, 0),
            thickness=1)
        x_min = int(min([point[0] for point in new_box]))
        y_min = int(min([point[1] for point in new_box]))

        # Compose "label(score)" for this node, mapping the index to a
        # class name when available.
        pred_label = str(node_pred_label[i])
        if pred_label in idx_to_cls:
            pred_label = idx_to_cls[pred_label]
        pred_score = '{:.2f}'.format(node_pred_score[i])
        text = pred_label + '(' + pred_score + ')'
        texts.append(text)

        # Place the text on the right half, roughly mirroring the box
        # position; font size follows the box's shorter side.
        font_size = int(
            min(
                abs(new_box[3][1] - new_box[0][1]),
                abs(new_box[1][0] - new_box[0][0])))
        char_num = len(text)
        text_box = [
            x_min * 2, y_min, x_min * 2 + font_size * char_num, y_min,
            x_min * 2 + font_size * char_num, y_min + font_size, x_min * 2,
            y_min + font_size
        ]
        text_boxes.append(text_box)

    pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255
    pred_img = draw_texts_by_pil(
        pred_img, texts, text_boxes, draw_box=False, on_ori_img=True)

    vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
    vis_img[:, :w] = img
    vis_img[:, w:] = pred_img

    if show:
        mmcv.imshow(vis_img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(vis_img, out_file)

    return vis_img
| |
|
| |
|
def gen_color():
    """Generate BGR color schemes."""
    return [
        (101, 67, 254),
        (154, 157, 252),
        (173, 205, 249),
        (123, 151, 138),
        (187, 200, 178),
        (148, 137, 69),
        (169, 200, 200),
        (155, 175, 131),
        (154, 194, 182),
        (178, 190, 137),
        (140, 211, 222),
        (83, 156, 222),
    ]
| |
|
| |
|
def draw_polygons(img, polys):
    """Draw polygons on image.

    Args:
        img (np.ndarray): The original image.
        polys (list[list[float]]): Detected polygons.
    Return:
        out_img (np.ndarray): Visualized image.
    """
    dst_img = img.copy()
    color_list = gen_color()
    out_img = dst_img
    # FIX: draw the filled polygons on a separate copy — the original
    # called cv2.drawContours on ``img`` itself, mutating the caller's
    # image as a side effect.
    fill_img = img.copy()
    for idx, poly in enumerate(polys):
        poly = np.array(poly).reshape((-1, 1, 2)).astype(np.int32)
        cv2.drawContours(
            fill_img,
            np.array([poly]),
            -1,
            color_list[idx % len(color_list)],
            thickness=cv2.FILLED)
    # Blend once after drawing every polygon (the per-iteration blend was
    # loop-invariant work: only the last iteration's result was kept).
    if len(polys) > 0:
        out_img = cv2.addWeighted(dst_img, 0.5, fill_img, 0.5, 0)
    return out_img
| |
|
| |
|
def get_optimal_font_scale(text, width):
    """Get optimal font scale for cv2.putText.

    Args:
        text (str): Text in one box.
        width (int): The box width.
    """
    # Walk down from scale 5.9 to 0.0 in 0.1 steps and return the first
    # scale whose rendered text width fits inside the box.
    for tenths in range(59, -1, -1):
        scale = tenths / 10
        (text_width, _), _ = cv2.getTextSize(
            text,
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=scale,
            thickness=1)
        if text_width <= width:
            return scale
    # Fallback when even the smallest scale does not fit.
    return 1
| |
|
| |
|
def draw_texts(img, texts, boxes=None, draw_box=True, on_ori_img=False):
    """Draw boxes and texts on empty img.

    Args:
        img (np.ndarray): The original image.
        texts (list[str]): Recognized texts.
        boxes (list[list[float]]): Detected bounding boxes.
        draw_box (bool): Whether draw box or not. If False, draw text only.
        on_ori_img (bool): If True, draw box and text on input image,
            else, on a new empty image.
    Return:
        out_img (np.ndarray): Visualized image.
    """
    palette = gen_color()
    h, w = img.shape[:2]
    if boxes is None:
        # Default to a single box covering the whole image.
        boxes = [[0, 0, w, 0, w, h, 0, h]]
    assert len(texts) == len(boxes)

    out_img = img if on_ori_img else np.full((h, w, 3), 255, dtype=np.uint8)

    for idx, (box, text) in enumerate(zip(boxes, texts)):
        xs, ys = box[0::2], box[1::2]
        if draw_box:
            corners = np.array([list(p) for p in zip(xs, ys)], np.int32)
            cv2.polylines(
                out_img, [corners.reshape((-1, 1, 2))],
                True,
                color=palette[idx % len(palette)],
                thickness=1)
        min_x = int(min(xs))
        # Baseline slightly below the box's vertical centre.
        max_y = int(np.mean(np.array(ys)) + 0.2 * (max(ys) - min(ys)))
        scale = get_optimal_font_scale(text, int(max(xs) - min(xs)))
        cv2.putText(out_img, text, (min_x, max_y), cv2.FONT_HERSHEY_SIMPLEX,
                    scale, (0, 0, 0), 1)

    return out_img
| |
|
| |
|
def draw_texts_by_pil(img,
                      texts,
                      boxes=None,
                      draw_box=True,
                      on_ori_img=False,
                      font_size=None,
                      fill_color=None,
                      draw_pos=None,
                      return_text_size=False):
    """Draw boxes and texts on empty image, especially for Chinese.

    Args:
        img (np.ndarray): The original image.
        texts (list[str]): Recognized texts.
        boxes (list[list[float]]): Detected bounding boxes.
        draw_box (bool): Whether draw box or not. If False, draw text only.
        on_ori_img (bool): If True, draw box and text on input image,
            else on a new empty image.
        font_size (int, optional): Size to create a font object for a font.
        fill_color (tuple(int), optional): Fill color for text.
        draw_pos (list[tuple(int)], optional): Start point to draw each text.
        return_text_size (bool): If True, return the list of text size.

    Returns:
        (np.ndarray, list[tuple]) or np.ndarray: Return a tuple
        ``(out_img, text_sizes)``, where ``out_img`` is the output image
        with texts drawn on it and ``text_sizes`` are the size of drawing
        texts. If ``return_text_size`` is False, only the output image will be
        returned.
    """

    color_list = gen_color()
    h, w = img.shape[:2]
    if boxes is None:
        boxes = [[0, 0, w, 0, w, h, 0, h]]
    if draw_pos is None:
        draw_pos = [None for _ in texts]
    assert len(boxes) == len(texts) == len(draw_pos)

    if fill_color is None:
        fill_color = (0, 0, 0)

    if on_ori_img:
        out_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    else:
        out_img = Image.new('RGB', (w, h), color=(255, 255, 255))
    out_draw = ImageDraw.Draw(out_img)

    # FIX: resolve (and, if needed, download) the font file ONCE before the
    # loop — the path is loop-invariant, but the original re-checked it for
    # every box.
    dirname, _ = os.path.split(os.path.abspath(__file__))
    font_path = os.path.join(dirname, 'font.TTF')
    if not os.path.exists(font_path):
        url = ('https://download.openmmlab.com/mmocr/data/font.TTF')
        print(f'Downloading {url} ...')
        local_filename, _ = urllib.request.urlretrieve(url)
        shutil.move(local_filename, font_path)

    text_sizes = []
    for idx, (box, text, ori_point) in enumerate(zip(boxes, texts, draw_pos)):
        if len(text) == 0:
            continue
        min_x, max_x = min(box[0::2]), max(box[0::2])
        min_y, max_y = min(box[1::2]), max(box[1::2])
        # gen_color() yields BGR; PIL wants RGB, so reverse the channels.
        color = tuple(list(color_list[idx % len(color_list)])[::-1])
        if draw_box:
            out_draw.line(box, fill=color, width=1)
        tmp_font_size = font_size
        if tmp_font_size is None:
            # Scale the font so the text roughly fills the box.
            box_width = max(max_x - min_x, max_y - min_y)
            tmp_font_size = int(0.9 * box_width / len(text))
        fnt = ImageFont.truetype(font_path, tmp_font_size)
        if ori_point is None:
            ori_point = (min_x + 1, min_y + 1)
        out_draw.text(ori_point, text, font=fnt, fill=fill_color)
        text_sizes.append(fnt.getsize(text))

    del out_draw

    # Back to an OpenCV-style BGR ndarray.
    out_img = cv2.cvtColor(np.asarray(out_img), cv2.COLOR_RGB2BGR)

    if return_text_size:
        return out_img, text_sizes

    return out_img
| |
|
| |
|
def is_contain_chinese(check_str):
    """Check whether string contains Chinese or not.

    Args:
        check_str (str): String to be checked.

    Return True if contains Chinese, else False.
    """
    # CJK Unified Ideographs occupy U+4E00..U+9FFF.
    return any(u'\u4e00' <= ch <= u'\u9fff' for ch in check_str)
| |
|
| |
|
def det_recog_show_result(img, end2end_res, out_file=None):
    """Draw `result`(boxes and texts) on `img`.

    Args:
        img (str or np.ndarray): The image to be displayed.
        end2end_res (dict): Text detect and recognize results.
        out_file (str): Image path where the visualized image should be saved.
    Return:
        out_img (np.ndarray): Visualized image.
    """
    img = mmcv.imread(img)
    boxes = [res['box'] for res in end2end_res['result']]
    texts = [res['text'] for res in end2end_res['result']]
    box_vis_img = draw_polygons(img, boxes)

    # Chinese glyphs need the PIL/TTF rendering path; ASCII uses cv2.
    if is_contain_chinese(''.join(texts)):
        text_vis_img = draw_texts_by_pil(img, texts, boxes)
    else:
        text_vis_img = draw_texts(img, texts, boxes)

    # Boxes on the left half, recognized texts on the right half.
    h, w = img.shape[:2]
    out_img = np.ones((h, w * 2, 3), dtype=np.uint8)
    out_img[:, :w, :] = box_vis_img
    out_img[:, w:, :] = text_vis_img

    if out_file:
        mmcv.imwrite(out_img, out_file)

    return out_img
| |
|
| |
|
def draw_edge_result(img, result, edge_thresh=0.5, keynode_thresh=0.5):
    """Draw text and their relationship on empty images.

    Args:
        img (np.ndarray): The original image.
        result (dict): The result of model forward_test, including:
            - img_metas (list[dict]): List of meta information dictionary.
            - nodes (Tensor): Node prediction with size:
                number_node * node_classes.
            - edges (Tensor): Edge prediction with size: number_edge * 2.
        edge_thresh (float): Score threshold for edge classification.
        keynode_thresh (float): Score threshold for node
            (``key``) classification.

    Returns:
        np.ndarray: The image with key, value and relation drawn on it.
    """

    h, w = img.shape[:2]

    # Layout constants: the canvas starts at two thirds of the source width
    # and grows (rightwards per column, downwards per pair) as needed.
    vis_area_width = w // 3 * 2
    vis_area_height = h
    dist_key_to_value = vis_area_width // 2
    dist_pair_to_pair = 30

    bbox_x1 = dist_pair_to_pair
    bbox_y1 = 0

    new_w = vis_area_width
    new_h = vis_area_height
    pred_edge_img = np.ones((new_h, new_w, 3), dtype=np.uint8) * 255

    nodes = result['nodes'].detach().cpu()
    texts = result['img_metas'][0]['ori_texts']
    num_nodes = result['nodes'].size(0)
    # Last edge-class scores reshaped to a dense num_nodes x num_nodes map.
    edges = result['edges'].detach().cpu()[:, -1].view(num_nodes, num_nodes)

    # Symmetrize the edge scores and keep node pairs whose best-direction
    # score exceeds the threshold.
    pairs = (torch.max(edges, edges.T) > edge_thresh).nonzero(as_tuple=True)
    pairs = (pairs[0].numpy().tolist(), pairs[1].numpy().tolist())

    # Orient each pair as (key, value): node columns 1 and 2 are compared to
    # decide which endpoint acts as the key. Only n1 < n2 is kept so each
    # undirected pair is processed once.
    # NOTE(review): columns 1/2 presumably correspond to key/value classes
    # of the KIE node classifier — confirm against the model definition.
    result_pairs = [(n1, n2) if nodes[n1, 1] > nodes[n1, 2] else (n2, n1)
                    for n1, n2 in zip(*pairs) if n1 < n2]

    result_pairs.sort()
    result_pairs_score = [
        torch.max(edges[n1, n2], edges[n2, n1]) for n1, n2 in result_pairs
    ]

    # Drawing state: current key index, arrow start point, and whether the
    # canvas must grow before drawing the next entry.
    key_current_idx = -1
    pos_current = (-1, -1)
    newline_flag = False

    key_font_size = 15
    value_font_size = 15
    key_font_color = (0, 0, 0)
    value_font_color = (0, 0, 255)
    arrow_color = (0, 0, 255)
    score_color = (0, 255, 0)
    for pair, pair_score in zip(result_pairs, result_pairs_score):
        key_idx = pair[0]
        # Skip pairs whose key node is not confidently a "key".
        if nodes[key_idx, 1] < keynode_thresh:
            continue
        if key_idx != key_current_idx:
            # A new key: move down, and start a fresh column (widening the
            # canvas) when the previous column ran out of vertical space.
            bbox_y1 += 10

            if newline_flag:
                bbox_x1 += vis_area_width
                tmp_img = np.ones(
                    (new_h, new_w + vis_area_width, 3), dtype=np.uint8) * 255
                tmp_img[:new_h, :new_w] = pred_edge_img
                pred_edge_img = tmp_img
                new_w += vis_area_width
                newline_flag = False
                bbox_y1 = 10
        key_text = texts[key_idx]
        key_pos = (bbox_x1, bbox_y1)
        value_idx = pair[1]
        value_text = texts[value_idx]
        value_pos = (bbox_x1 + dist_key_to_value, bbox_y1)
        if key_idx != key_current_idx:
            # Draw the key text once per key, plus the arrow to its value.
            key_current_idx = key_idx
            pred_edge_img, text_sizes = draw_texts_by_pil(
                pred_edge_img, [key_text],
                draw_box=False,
                on_ori_img=True,
                font_size=key_font_size,
                fill_color=key_font_color,
                draw_pos=[key_pos],
                return_text_size=True)
            pos_right_bottom = (key_pos[0] + text_sizes[0][0],
                                key_pos[1] + text_sizes[0][1])
            pos_current = (pos_right_bottom[0] + 5, bbox_y1 + 10)
            pred_edge_img = cv2.arrowedLine(
                pred_edge_img, (pos_right_bottom[0] + 5, bbox_y1 + 10),
                (bbox_x1 + dist_key_to_value - 5, bbox_y1 + 10), arrow_color,
                1)
            score_pos_x = int(
                (pos_right_bottom[0] + bbox_x1 + dist_key_to_value) / 2.)
            score_pos_y = bbox_y1 + 10 - int(key_font_size * 0.3)
        else:
            # Another value for the same key: grow the canvas downwards if
            # needed, then draw an arrow from the key to this value.
            if newline_flag:
                tmp_img = np.ones((new_h + dist_pair_to_pair, new_w, 3),
                                  dtype=np.uint8) * 255
                tmp_img[:new_h, :new_w] = pred_edge_img
                pred_edge_img = tmp_img
                new_h += dist_pair_to_pair
            pred_edge_img = cv2.arrowedLine(pred_edge_img, pos_current,
                                            (bbox_x1 + dist_key_to_value - 5,
                                             bbox_y1 + 10), arrow_color, 1)
            score_pos_x = int(
                (pos_current[0] + bbox_x1 + dist_key_to_value - 5) / 2.)
            score_pos_y = int((pos_current[1] + bbox_y1 + 10) / 2.)
        # Edge score, drawn midway along the arrow.
        cv2.putText(pred_edge_img, '{:.2f}'.format(pair_score),
                    (score_pos_x, score_pos_y), cv2.FONT_HERSHEY_COMPLEX, 0.4,
                    score_color)
        # The value text at the arrow's head.
        pred_edge_img = draw_texts_by_pil(
            pred_edge_img, [value_text],
            draw_box=False,
            on_ori_img=True,
            font_size=value_font_size,
            fill_color=value_font_color,
            draw_pos=[value_pos],
            return_text_size=False)
        bbox_y1 += dist_pair_to_pair
        if bbox_y1 + dist_pair_to_pair >= new_h:
            newline_flag = True

    return pred_edge_img
| |
|
| |
|
def imshow_edge(img,
                result,
                boxes,
                show=False,
                win_name='',
                wait_time=-1,
                out_file=None):
    """Display the prediction results of the nodes and edges of the KIE model.

    Args:
        img (np.ndarray): The original image.
        result (dict): The result of model forward_test, including:
            - img_metas (list[dict]): List of meta information dictionary.
            - nodes (Tensor): Node prediction with size: \
                number_node * node_classes.
            - edges (Tensor): Edge prediction with size: number_edge * 2.
        boxes (list): The text boxes corresponding to the nodes.
        show (bool): Whether to show the image. Default: False.
        win_name (str): The window name. Default: ''
        wait_time (float): Value of waitKey param. Default: 0.
        out_file (str or None): The filename to write the image.
            Default: None.

    Returns:
        np.ndarray: The image with key, value and relation drawn on it.
    """
    img = mmcv.imread(img)
    h, w = img.shape[:2]
    color_list = gen_color()

    # Outline each node's text box, cycling through the palette.
    for i, box in enumerate(boxes):
        new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
                   [box[0], box[3]]]
        Pts = np.array([new_box], np.int32)
        cv2.polylines(
            img, [Pts.reshape((-1, 1, 2))],
            True,
            color=color_list[i % len(color_list)],
            thickness=1)

    pred_img_h = h
    pred_img_w = w

    # The edge visualization can be wider/taller than the source image.
    pred_edge_img = draw_edge_result(img, result)
    pred_img_h = max(pred_img_h, pred_edge_img.shape[0])
    pred_img_w += pred_edge_img.shape[1]

    vis_img = np.zeros((pred_img_h, pred_img_w, 3), dtype=np.uint8)
    vis_img[:h, :w] = img
    vis_img[:, w:] = 255

    height_t, width_t = pred_edge_img.shape[:2]
    vis_img[:height_t, w:(w + width_t)] = pred_edge_img

    if show:
        mmcv.imshow(vis_img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(vis_img, out_file)
        # FIX: only dump the raw predictions when an output path is given;
        # the original dumped unconditionally and wrote 'None_res.pkl'
        # whenever out_file was None.
        res_dic = {
            'boxes': boxes,
            'nodes': result['nodes'].detach().cpu(),
            'edges': result['edges'].detach().cpu(),
            'metas': result['img_metas'][0]
        }
        mmcv.dump(res_dic, f'{out_file}_res.pkl')

    return vis_img
| |
|