Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import math | |
import os | |
import shutil | |
import urllib | |
import warnings | |
import cv2 | |
import mmcv | |
import numpy as np | |
import torch | |
from matplotlib import pyplot as plt | |
from PIL import Image, ImageDraw, ImageFont | |
import mmocr.utils as utils | |
def overlay_mask_img(img, mask): | |
"""Draw mask boundaries on image for visualization. | |
Args: | |
img (ndarray): The input image. | |
mask (ndarray): The instance mask. | |
Returns: | |
img (ndarray): The output image with instance boundaries on it. | |
""" | |
assert isinstance(img, np.ndarray) | |
assert isinstance(mask, np.ndarray) | |
contours, _ = cv2.findContours( | |
mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | |
cv2.drawContours(img, contours, -1, (0, 255, 0), 1) | |
return img | |
def show_feature(features, names, to_uint8, out_file=None): | |
"""Visualize a list of feature maps. | |
Args: | |
features (list(ndarray)): The feature map list. | |
names (list(str)): The visualized title list. | |
to_uint8 (list(1|0)): The list indicating whether to convent | |
feature maps to uint8. | |
out_file (str): The output file name. If set to None, | |
the output image will be shown without saving. | |
""" | |
assert utils.is_type_list(features, np.ndarray) | |
assert utils.is_type_list(names, str) | |
assert utils.is_type_list(to_uint8, int) | |
assert utils.is_none_or_type(out_file, str) | |
assert utils.equal_len(features, names, to_uint8) | |
num = len(features) | |
row = col = math.ceil(math.sqrt(num)) | |
for i, (f, n) in enumerate(zip(features, names)): | |
plt.subplot(row, col, i + 1) | |
plt.title(n) | |
if to_uint8[i]: | |
f = f.astype(np.uint8) | |
plt.imshow(f) | |
if out_file is None: | |
plt.show() | |
else: | |
plt.savefig(out_file) | |
def show_img_boundary(img, boundary): | |
"""Show image and instance boundaires. | |
Args: | |
img (ndarray): The input image. | |
boundary (list[float or int]): The input boundary. | |
""" | |
assert isinstance(img, np.ndarray) | |
assert utils.is_type_list(boundary, (int, float)) | |
cv2.polylines( | |
img, [np.array(boundary).astype(np.int32).reshape(-1, 1, 2)], | |
True, | |
color=(0, 255, 0), | |
thickness=1) | |
plt.imshow(img) | |
plt.show() | |
def show_pred_gt(preds, | |
gts, | |
show=False, | |
win_name='', | |
wait_time=0, | |
out_file=None): | |
"""Show detection and ground truth for one image. | |
Args: | |
preds (list[list[float]]): The detection boundary list. | |
gts (list[list[float]]): The ground truth boundary list. | |
show (bool): Whether to show the image. | |
win_name (str): The window name. | |
wait_time (int): The value of waitKey param. | |
out_file (str): The filename of the output. | |
""" | |
assert utils.is_2dlist(preds) | |
assert utils.is_2dlist(gts) | |
assert isinstance(show, bool) | |
assert isinstance(win_name, str) | |
assert isinstance(wait_time, int) | |
assert utils.is_none_or_type(out_file, str) | |
p_xy = [p for boundary in preds for p in boundary] | |
gt_xy = [g for gt in gts for g in gt] | |
max_xy = np.max(np.array(p_xy + gt_xy).reshape(-1, 2), axis=0) | |
width = int(max_xy[0]) + 100 | |
height = int(max_xy[1]) + 100 | |
img = np.ones((height, width, 3), np.int8) * 255 | |
pred_color = mmcv.color_val('red') | |
gt_color = mmcv.color_val('blue') | |
thickness = 1 | |
for boundary in preds: | |
cv2.polylines( | |
img, [np.array(boundary).astype(np.int32).reshape(-1, 1, 2)], | |
True, | |
color=pred_color, | |
thickness=thickness) | |
for gt in gts: | |
cv2.polylines( | |
img, [np.array(gt).astype(np.int32).reshape(-1, 1, 2)], | |
True, | |
color=gt_color, | |
thickness=thickness) | |
if show: | |
mmcv.imshow(img, win_name, wait_time) | |
if out_file is not None: | |
mmcv.imwrite(img, out_file) | |
return img | |
def imshow_pred_boundary(img, | |
boundaries_with_scores, | |
labels, | |
score_thr=0, | |
boundary_color='blue', | |
text_color='blue', | |
thickness=1, | |
font_scale=0.5, | |
show=True, | |
win_name='', | |
wait_time=0, | |
out_file=None, | |
show_score=False): | |
"""Draw boundaries and class labels (with scores) on an image. | |
Args: | |
img (str or ndarray): The image to be displayed. | |
boundaries_with_scores (list[list[float]]): Boundaries with scores. | |
labels (list[int]): Labels of boundaries. | |
score_thr (float): Minimum score of boundaries to be shown. | |
boundary_color (str or tuple or :obj:`Color`): Color of boundaries. | |
text_color (str or tuple or :obj:`Color`): Color of texts. | |
thickness (int): Thickness of lines. | |
font_scale (float): Font scales of texts. | |
show (bool): Whether to show the image. | |
win_name (str): The window name. | |
wait_time (int): Value of waitKey param. | |
out_file (str or None): The filename of the output. | |
show_score (bool): Whether to show text instance score. | |
""" | |
assert isinstance(img, (str, np.ndarray)) | |
assert utils.is_2dlist(boundaries_with_scores) | |
assert utils.is_type_list(labels, int) | |
assert utils.equal_len(boundaries_with_scores, labels) | |
if len(boundaries_with_scores) == 0: | |
warnings.warn('0 text found in ' + out_file) | |
return None | |
utils.valid_boundary(boundaries_with_scores[0]) | |
img = mmcv.imread(img) | |
scores = np.array([b[-1] for b in boundaries_with_scores]) | |
inds = scores > score_thr | |
boundaries = [boundaries_with_scores[i][:-1] for i in np.where(inds)[0]] | |
scores = [scores[i] for i in np.where(inds)[0]] | |
labels = [labels[i] for i in np.where(inds)[0]] | |
boundary_color = mmcv.color_val(boundary_color) | |
text_color = mmcv.color_val(text_color) | |
font_scale = 0.5 | |
for boundary, score in zip(boundaries, scores): | |
boundary_int = np.array(boundary).astype(np.int32) | |
cv2.polylines( | |
img, [boundary_int.reshape(-1, 1, 2)], | |
True, | |
color=boundary_color, | |
thickness=thickness) | |
if show_score: | |
label_text = f'{score:.02f}' | |
cv2.putText(img, label_text, | |
(boundary_int[0], boundary_int[1] - 2), | |
cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color) | |
if show: | |
mmcv.imshow(img, win_name, wait_time) | |
if out_file is not None: | |
mmcv.imwrite(img, out_file) | |
return img | |
def imshow_text_char_boundary(img, | |
text_quads, | |
boundaries, | |
char_quads, | |
chars, | |
show=False, | |
thickness=1, | |
font_scale=0.5, | |
win_name='', | |
wait_time=-1, | |
out_file=None): | |
"""Draw text boxes and char boxes on img. | |
Args: | |
img (str or ndarray): The img to be displayed. | |
text_quads (list[list[int|float]]): The text boxes. | |
boundaries (list[list[int|float]]): The boundary list. | |
char_quads (list[list[list[int|float]]]): A 2d list of char boxes. | |
char_quads[i] is for the ith text, and char_quads[i][j] is the jth | |
char of the ith text. | |
chars (list[list[char]]). The string for each text box. | |
thickness (int): Thickness of lines. | |
font_scale (float): Font scales of texts. | |
show (bool): Whether to show the image. | |
win_name (str): The window name. | |
wait_time (int): Value of waitKey param. | |
out_file (str or None): The filename of the output. | |
""" | |
assert isinstance(img, (np.ndarray, str)) | |
assert utils.is_2dlist(text_quads) | |
assert utils.is_2dlist(boundaries) | |
assert utils.is_3dlist(char_quads) | |
assert utils.is_2dlist(chars) | |
assert utils.equal_len(text_quads, char_quads, boundaries) | |
img = mmcv.imread(img) | |
char_color = [mmcv.color_val('blue'), mmcv.color_val('green')] | |
text_color = mmcv.color_val('red') | |
text_inx = 0 | |
for text_box, boundary, char_box, txt in zip(text_quads, boundaries, | |
char_quads, chars): | |
text_box = np.array(text_box) | |
boundary = np.array(boundary) | |
text_box = text_box.reshape(-1, 2).astype(np.int32) | |
cv2.polylines( | |
img, [text_box.reshape(-1, 1, 2)], | |
True, | |
color=text_color, | |
thickness=thickness) | |
if boundary.shape[0] > 0: | |
cv2.polylines( | |
img, [boundary.reshape(-1, 1, 2)], | |
True, | |
color=text_color, | |
thickness=thickness) | |
for b in char_box: | |
b = np.array(b) | |
c = char_color[text_inx % 2] | |
b = b.astype(np.int32) | |
cv2.polylines( | |
img, [b.reshape(-1, 1, 2)], True, color=c, thickness=thickness) | |
label_text = ''.join(txt) | |
cv2.putText(img, label_text, (text_box[0, 0], text_box[0, 1] - 2), | |
cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color) | |
text_inx = text_inx + 1 | |
if show: | |
mmcv.imshow(img, win_name, wait_time) | |
if out_file is not None: | |
mmcv.imwrite(img, out_file) | |
return img | |
def tile_image(images): | |
"""Combined multiple images to one vertically. | |
Args: | |
images (list[np.ndarray]): Images to be combined. | |
""" | |
assert isinstance(images, list) | |
assert len(images) > 0 | |
for i, _ in enumerate(images): | |
if len(images[i].shape) == 2: | |
images[i] = cv2.cvtColor(images[i], cv2.COLOR_GRAY2BGR) | |
widths = [img.shape[1] for img in images] | |
heights = [img.shape[0] for img in images] | |
h, w = sum(heights), max(widths) | |
vis_img = np.zeros((h, w, 3), dtype=np.uint8) | |
offset_y = 0 | |
for image in images: | |
img_h, img_w = image.shape[:2] | |
vis_img[offset_y:(offset_y + img_h), 0:img_w, :] = image | |
offset_y += img_h | |
return vis_img | |
def imshow_text_label(img, | |
pred_label, | |
gt_label, | |
show=False, | |
win_name='', | |
wait_time=-1, | |
out_file=None): | |
"""Draw predicted texts and ground truth texts on images. | |
Args: | |
img (str or np.ndarray): Image filename or loaded image. | |
pred_label (str): Predicted texts. | |
gt_label (str): Ground truth texts. | |
show (bool): Whether to show the image. | |
win_name (str): The window name. | |
wait_time (int): Value of waitKey param. | |
out_file (str): The filename of the output. | |
""" | |
assert isinstance(img, (np.ndarray, str)) | |
assert isinstance(pred_label, str) | |
assert isinstance(gt_label, str) | |
assert isinstance(show, bool) | |
assert isinstance(win_name, str) | |
assert isinstance(wait_time, int) | |
img = mmcv.imread(img) | |
src_h, src_w = img.shape[:2] | |
resize_height = 64 | |
resize_width = int(1.0 * src_w / src_h * resize_height) | |
img = cv2.resize(img, (resize_width, resize_height)) | |
h, w = img.shape[:2] | |
if is_contain_chinese(pred_label): | |
pred_img = draw_texts_by_pil(img, [pred_label], None) | |
else: | |
pred_img = np.ones((h, w, 3), dtype=np.uint8) * 255 | |
cv2.putText(pred_img, pred_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX, | |
0.9, (0, 0, 255), 2) | |
images = [pred_img, img] | |
if gt_label != '': | |
if is_contain_chinese(gt_label): | |
gt_img = draw_texts_by_pil(img, [gt_label], None) | |
else: | |
gt_img = np.ones((h, w, 3), dtype=np.uint8) * 255 | |
cv2.putText(gt_img, gt_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX, | |
0.9, (255, 0, 0), 2) | |
images.append(gt_img) | |
img = tile_image(images) | |
if show: | |
mmcv.imshow(img, win_name, wait_time) | |
if out_file is not None: | |
mmcv.imwrite(img, out_file) | |
return img | |
def imshow_node(img, | |
result, | |
boxes, | |
idx_to_cls={}, | |
show=False, | |
win_name='', | |
wait_time=-1, | |
out_file=None): | |
img = mmcv.imread(img) | |
h, w = img.shape[:2] | |
max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1) | |
node_pred_label = max_idx.numpy().tolist() | |
node_pred_score = max_value.numpy().tolist() | |
texts, text_boxes = [], [] | |
for i, box in enumerate(boxes): | |
new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]], | |
[box[0], box[3]]] | |
Pts = np.array([new_box], np.int32) | |
cv2.polylines( | |
img, [Pts.reshape((-1, 1, 2))], | |
True, | |
color=(255, 255, 0), | |
thickness=1) | |
x_min = int(min([point[0] for point in new_box])) | |
y_min = int(min([point[1] for point in new_box])) | |
# text | |
pred_label = str(node_pred_label[i]) | |
if pred_label in idx_to_cls: | |
pred_label = idx_to_cls[pred_label] | |
pred_score = '{:.2f}'.format(node_pred_score[i]) | |
text = pred_label + '(' + pred_score + ')' | |
texts.append(text) | |
# text box | |
font_size = int( | |
min( | |
abs(new_box[3][1] - new_box[0][1]), | |
abs(new_box[1][0] - new_box[0][0]))) | |
char_num = len(text) | |
text_box = [ | |
x_min * 2, y_min, x_min * 2 + font_size * char_num, y_min, | |
x_min * 2 + font_size * char_num, y_min + font_size, x_min * 2, | |
y_min + font_size | |
] | |
text_boxes.append(text_box) | |
pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255 | |
pred_img = draw_texts_by_pil( | |
pred_img, texts, text_boxes, draw_box=False, on_ori_img=True) | |
vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255 | |
vis_img[:, :w] = img | |
vis_img[:, w:] = pred_img | |
if show: | |
mmcv.imshow(vis_img, win_name, wait_time) | |
if out_file is not None: | |
mmcv.imwrite(vis_img, out_file) | |
return vis_img | |
def gen_color(): | |
"""Generate BGR color schemes.""" | |
color_list = [(101, 67, 254), (154, 157, 252), (173, 205, 249), | |
(123, 151, 138), (187, 200, 178), (148, 137, 69), | |
(169, 200, 200), (155, 175, 131), (154, 194, 182), | |
(178, 190, 137), (140, 211, 222), (83, 156, 222)] | |
return color_list | |
def draw_polygons(img, polys): | |
"""Draw polygons on image. | |
Args: | |
img (np.ndarray): The original image. | |
polys (list[list[float]]): Detected polygons. | |
Return: | |
out_img (np.ndarray): Visualized image. | |
""" | |
dst_img = img.copy() | |
color_list = gen_color() | |
out_img = dst_img | |
for idx, poly in enumerate(polys): | |
poly = np.array(poly).reshape((-1, 1, 2)).astype(np.int32) | |
cv2.drawContours( | |
img, | |
np.array([poly]), | |
-1, | |
color_list[idx % len(color_list)], | |
thickness=cv2.FILLED) | |
out_img = cv2.addWeighted(dst_img, 0.5, img, 0.5, 0) | |
return out_img | |
def get_optimal_font_scale(text, width): | |
"""Get optimal font scale for cv2.putText. | |
Args: | |
text (str): Text in one box. | |
width (int): The box width. | |
""" | |
for scale in reversed(range(0, 60, 1)): | |
textSize = cv2.getTextSize( | |
text, | |
fontFace=cv2.FONT_HERSHEY_SIMPLEX, | |
fontScale=scale / 10, | |
thickness=1) | |
new_width = textSize[0][0] | |
if new_width <= width: | |
return scale / 10 | |
return 1 | |
def draw_texts(img, texts, boxes=None, draw_box=True, on_ori_img=False): | |
"""Draw boxes and texts on empty img. | |
Args: | |
img (np.ndarray): The original image. | |
texts (list[str]): Recognized texts. | |
boxes (list[list[float]]): Detected bounding boxes. | |
draw_box (bool): Whether draw box or not. If False, draw text only. | |
on_ori_img (bool): If True, draw box and text on input image, | |
else, on a new empty image. | |
Return: | |
out_img (np.ndarray): Visualized image. | |
""" | |
color_list = gen_color() | |
h, w = img.shape[:2] | |
if boxes is None: | |
boxes = [[0, 0, w, 0, w, h, 0, h]] | |
assert len(texts) == len(boxes) | |
if on_ori_img: | |
out_img = img | |
else: | |
out_img = np.ones((h, w, 3), dtype=np.uint8) * 255 | |
for idx, (box, text) in enumerate(zip(boxes, texts)): | |
if draw_box: | |
new_box = [[x, y] for x, y in zip(box[0::2], box[1::2])] | |
Pts = np.array([new_box], np.int32) | |
cv2.polylines( | |
out_img, [Pts.reshape((-1, 1, 2))], | |
True, | |
color=color_list[idx % len(color_list)], | |
thickness=1) | |
min_x = int(min(box[0::2])) | |
max_y = int( | |
np.mean(np.array(box[1::2])) + 0.2 * | |
(max(box[1::2]) - min(box[1::2]))) | |
font_scale = get_optimal_font_scale( | |
text, int(max(box[0::2]) - min(box[0::2]))) | |
cv2.putText(out_img, text, (min_x, max_y), cv2.FONT_HERSHEY_SIMPLEX, | |
font_scale, (0, 0, 0), 1) | |
return out_img | |
def draw_texts_by_pil(img, | |
texts, | |
boxes=None, | |
draw_box=True, | |
on_ori_img=False, | |
font_size=None, | |
fill_color=None, | |
draw_pos=None, | |
return_text_size=False): | |
"""Draw boxes and texts on empty image, especially for Chinese. | |
Args: | |
img (np.ndarray): The original image. | |
texts (list[str]): Recognized texts. | |
boxes (list[list[float]]): Detected bounding boxes. | |
draw_box (bool): Whether draw box or not. If False, draw text only. | |
on_ori_img (bool): If True, draw box and text on input image, | |
else on a new empty image. | |
font_size (int, optional): Size to create a font object for a font. | |
fill_color (tuple(int), optional): Fill color for text. | |
draw_pos (list[tuple(int)], optional): Start point to draw each text. | |
return_text_size (bool): If True, return the list of text size. | |
Returns: | |
(np.ndarray, list[tuple]) or np.ndarray: Return a tuple | |
``(out_img, text_sizes)``, where ``out_img`` is the output image | |
with texts drawn on it and ``text_sizes`` are the size of drawing | |
texts. If ``return_text_size`` is False, only the output image will be | |
returned. | |
""" | |
color_list = gen_color() | |
h, w = img.shape[:2] | |
if boxes is None: | |
boxes = [[0, 0, w, 0, w, h, 0, h]] | |
if draw_pos is None: | |
draw_pos = [None for _ in texts] | |
assert len(boxes) == len(texts) == len(draw_pos) | |
if fill_color is None: | |
fill_color = (0, 0, 0) | |
if on_ori_img: | |
out_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) | |
else: | |
out_img = Image.new('RGB', (w, h), color=(255, 255, 255)) | |
out_draw = ImageDraw.Draw(out_img) | |
text_sizes = [] | |
for idx, (box, text, ori_point) in enumerate(zip(boxes, texts, draw_pos)): | |
if len(text) == 0: | |
continue | |
min_x, max_x = min(box[0::2]), max(box[0::2]) | |
min_y, max_y = min(box[1::2]), max(box[1::2]) | |
color = tuple(list(color_list[idx % len(color_list)])[::-1]) | |
if draw_box: | |
out_draw.line(box, fill=color, width=1) | |
dirname, _ = os.path.split(os.path.abspath(__file__)) | |
font_path = os.path.join(dirname, 'font.TTF') | |
if not os.path.exists(font_path): | |
url = ('https://download.openmmlab.com/mmocr/data/font.TTF') | |
print(f'Downloading {url} ...') | |
local_filename, _ = urllib.request.urlretrieve(url) | |
shutil.move(local_filename, font_path) | |
tmp_font_size = font_size | |
if tmp_font_size is None: | |
box_width = max(max_x - min_x, max_y - min_y) | |
tmp_font_size = int(0.9 * box_width / len(text)) | |
fnt = ImageFont.truetype(font_path, tmp_font_size) | |
if ori_point is None: | |
ori_point = (min_x + 1, min_y + 1) | |
out_draw.text(ori_point, text, font=fnt, fill=fill_color) | |
text_sizes.append(fnt.getsize(text)) | |
del out_draw | |
out_img = cv2.cvtColor(np.asarray(out_img), cv2.COLOR_RGB2BGR) | |
if return_text_size: | |
return out_img, text_sizes | |
return out_img | |
def is_contain_chinese(check_str): | |
"""Check whether string contains Chinese or not. | |
Args: | |
check_str (str): String to be checked. | |
Return True if contains Chinese, else False. | |
""" | |
for ch in check_str: | |
if u'\u4e00' <= ch <= u'\u9fff': | |
return True | |
return False | |
def det_recog_show_result(img, end2end_res, out_file=None): | |
"""Draw `result`(boxes and texts) on `img`. | |
Args: | |
img (str or np.ndarray): The image to be displayed. | |
end2end_res (dict): Text detect and recognize results. | |
out_file (str): Image path where the visualized image should be saved. | |
Return: | |
out_img (np.ndarray): Visualized image. | |
""" | |
img = mmcv.imread(img) | |
boxes, texts = [], [] | |
for res in end2end_res['result']: | |
boxes.append(res['box']) | |
texts.append(res['text']) | |
box_vis_img = draw_polygons(img, boxes) | |
if is_contain_chinese(''.join(texts)): | |
text_vis_img = draw_texts_by_pil(img, texts, boxes) | |
else: | |
text_vis_img = draw_texts(img, texts, boxes) | |
h, w = img.shape[:2] | |
out_img = np.ones((h, w * 2, 3), dtype=np.uint8) | |
out_img[:, :w, :] = box_vis_img | |
out_img[:, w:, :] = text_vis_img | |
if out_file: | |
mmcv.imwrite(out_img, out_file) | |
return out_img | |
def draw_edge_result(img, result, edge_thresh=0.5, keynode_thresh=0.5): | |
"""Draw text and their relationship on empty images. | |
Args: | |
img (np.ndarray): The original image. | |
result (dict): The result of model forward_test, including: | |
- img_metas (list[dict]): List of meta information dictionary. | |
- nodes (Tensor): Node prediction with size: | |
number_node * node_classes. | |
- edges (Tensor): Edge prediction with size: number_edge * 2. | |
edge_thresh (float): Score threshold for edge classification. | |
keynode_thresh (float): Score threshold for node | |
(``key``) classification. | |
Returns: | |
np.ndarray: The image with key, value and relation drawn on it. | |
""" | |
h, w = img.shape[:2] | |
vis_area_width = w // 3 * 2 | |
vis_area_height = h | |
dist_key_to_value = vis_area_width // 2 | |
dist_pair_to_pair = 30 | |
bbox_x1 = dist_pair_to_pair | |
bbox_y1 = 0 | |
new_w = vis_area_width | |
new_h = vis_area_height | |
pred_edge_img = np.ones((new_h, new_w, 3), dtype=np.uint8) * 255 | |
nodes = result['nodes'].detach().cpu() | |
texts = result['img_metas'][0]['ori_texts'] | |
num_nodes = result['nodes'].size(0) | |
edges = result['edges'].detach().cpu()[:, -1].view(num_nodes, num_nodes) | |
# (i, j) will be a valid pair | |
# either edge_score(node_i->node_j) > edge_thresh | |
# or edge_score(node_j->node_i) > edge_thresh | |
pairs = (torch.max(edges, edges.T) > edge_thresh).nonzero(as_tuple=True) | |
pairs = (pairs[0].numpy().tolist(), pairs[1].numpy().tolist()) | |
# 1. "for n1, n2 in zip(*pairs) if n1 < n2": | |
# Only (n1, n2) will be included if n1 < n2 but not (n2, n1), to | |
# avoid duplication. | |
# 2. "(n1, n2) if nodes[n1, 1] > nodes[n1, 2]": | |
# nodes[n1, 1] is the score that this node is predicted as key, | |
# nodes[n1, 2] is the score that this node is predicted as value. | |
# If nodes[n1, 1] > nodes[n1, 2], n1 will be the index of key, | |
# so that n2 will be the index of value. | |
result_pairs = [(n1, n2) if nodes[n1, 1] > nodes[n1, 2] else (n2, n1) | |
for n1, n2 in zip(*pairs) if n1 < n2] | |
result_pairs.sort() | |
result_pairs_score = [ | |
torch.max(edges[n1, n2], edges[n2, n1]) for n1, n2 in result_pairs | |
] | |
key_current_idx = -1 | |
pos_current = (-1, -1) | |
newline_flag = False | |
key_font_size = 15 | |
value_font_size = 15 | |
key_font_color = (0, 0, 0) | |
value_font_color = (0, 0, 255) | |
arrow_color = (0, 0, 255) | |
score_color = (0, 255, 0) | |
for pair, pair_score in zip(result_pairs, result_pairs_score): | |
key_idx = pair[0] | |
if nodes[key_idx, 1] < keynode_thresh: | |
continue | |
if key_idx != key_current_idx: | |
# move y-coords down for a new key | |
bbox_y1 += 10 | |
# enlarge blank area to show key-value info | |
if newline_flag: | |
bbox_x1 += vis_area_width | |
tmp_img = np.ones( | |
(new_h, new_w + vis_area_width, 3), dtype=np.uint8) * 255 | |
tmp_img[:new_h, :new_w] = pred_edge_img | |
pred_edge_img = tmp_img | |
new_w += vis_area_width | |
newline_flag = False | |
bbox_y1 = 10 | |
key_text = texts[key_idx] | |
key_pos = (bbox_x1, bbox_y1) | |
value_idx = pair[1] | |
value_text = texts[value_idx] | |
value_pos = (bbox_x1 + dist_key_to_value, bbox_y1) | |
if key_idx != key_current_idx: | |
# draw text for a new key | |
key_current_idx = key_idx | |
pred_edge_img, text_sizes = draw_texts_by_pil( | |
pred_edge_img, [key_text], | |
draw_box=False, | |
on_ori_img=True, | |
font_size=key_font_size, | |
fill_color=key_font_color, | |
draw_pos=[key_pos], | |
return_text_size=True) | |
pos_right_bottom = (key_pos[0] + text_sizes[0][0], | |
key_pos[1] + text_sizes[0][1]) | |
pos_current = (pos_right_bottom[0] + 5, bbox_y1 + 10) | |
pred_edge_img = cv2.arrowedLine( | |
pred_edge_img, (pos_right_bottom[0] + 5, bbox_y1 + 10), | |
(bbox_x1 + dist_key_to_value - 5, bbox_y1 + 10), arrow_color, | |
1) | |
score_pos_x = int( | |
(pos_right_bottom[0] + bbox_x1 + dist_key_to_value) / 2.) | |
score_pos_y = bbox_y1 + 10 - int(key_font_size * 0.3) | |
else: | |
# draw arrow from key to value | |
if newline_flag: | |
tmp_img = np.ones((new_h + dist_pair_to_pair, new_w, 3), | |
dtype=np.uint8) * 255 | |
tmp_img[:new_h, :new_w] = pred_edge_img | |
pred_edge_img = tmp_img | |
new_h += dist_pair_to_pair | |
pred_edge_img = cv2.arrowedLine(pred_edge_img, pos_current, | |
(bbox_x1 + dist_key_to_value - 5, | |
bbox_y1 + 10), arrow_color, 1) | |
score_pos_x = int( | |
(pos_current[0] + bbox_x1 + dist_key_to_value - 5) / 2.) | |
score_pos_y = int((pos_current[1] + bbox_y1 + 10) / 2.) | |
# draw edge score | |
cv2.putText(pred_edge_img, '{:.2f}'.format(pair_score), | |
(score_pos_x, score_pos_y), cv2.FONT_HERSHEY_COMPLEX, 0.4, | |
score_color) | |
# draw text for value | |
pred_edge_img = draw_texts_by_pil( | |
pred_edge_img, [value_text], | |
draw_box=False, | |
on_ori_img=True, | |
font_size=value_font_size, | |
fill_color=value_font_color, | |
draw_pos=[value_pos], | |
return_text_size=False) | |
bbox_y1 += dist_pair_to_pair | |
if bbox_y1 + dist_pair_to_pair >= new_h: | |
newline_flag = True | |
return pred_edge_img | |
def imshow_edge(img, | |
result, | |
boxes, | |
show=False, | |
win_name='', | |
wait_time=-1, | |
out_file=None): | |
"""Display the prediction results of the nodes and edges of the KIE model. | |
Args: | |
img (np.ndarray): The original image. | |
result (dict): The result of model forward_test, including: | |
- img_metas (list[dict]): List of meta information dictionary. | |
- nodes (Tensor): Node prediction with size: \ | |
number_node * node_classes. | |
- edges (Tensor): Edge prediction with size: number_edge * 2. | |
boxes (list): The text boxes corresponding to the nodes. | |
show (bool): Whether to show the image. Default: False. | |
win_name (str): The window name. Default: '' | |
wait_time (float): Value of waitKey param. Default: 0. | |
out_file (str or None): The filename to write the image. | |
Default: None. | |
Returns: | |
np.ndarray: The image with key, value and relation drawn on it. | |
""" | |
img = mmcv.imread(img) | |
h, w = img.shape[:2] | |
color_list = gen_color() | |
for i, box in enumerate(boxes): | |
new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]], | |
[box[0], box[3]]] | |
Pts = np.array([new_box], np.int32) | |
cv2.polylines( | |
img, [Pts.reshape((-1, 1, 2))], | |
True, | |
color=color_list[i % len(color_list)], | |
thickness=1) | |
pred_img_h = h | |
pred_img_w = w | |
pred_edge_img = draw_edge_result(img, result) | |
pred_img_h = max(pred_img_h, pred_edge_img.shape[0]) | |
pred_img_w += pred_edge_img.shape[1] | |
vis_img = np.zeros((pred_img_h, pred_img_w, 3), dtype=np.uint8) | |
vis_img[:h, :w] = img | |
vis_img[:, w:] = 255 | |
height_t, width_t = pred_edge_img.shape[:2] | |
vis_img[:height_t, w:(w + width_t)] = pred_edge_img | |
if show: | |
mmcv.imshow(vis_img, win_name, wait_time) | |
if out_file is not None: | |
mmcv.imwrite(vis_img, out_file) | |
res_dic = { | |
'boxes': boxes, | |
'nodes': result['nodes'].detach().cpu(), | |
'edges': result['edges'].detach().cpu(), | |
'metas': result['img_metas'][0] | |
} | |
mmcv.dump(res_dic, f'{out_file}_res.pkl') | |
return vis_img | |