Spaces:
Running
Running
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp | |
import tempfile | |
import unittest | |
import cv2 | |
import numpy as np | |
import torch | |
from mmengine.structures import InstanceData | |
from mmocr.structures import KIEDataSample | |
from mmocr.utils import bbox2poly | |
from mmocr.visualization import KIELocalVisualizer | |
class TestTextKIELocalVisualizer(unittest.TestCase):
    """Checks the image files written by ``KIELocalVisualizer.add_datasample``.

    ``setUp`` builds a random 12x10 RGB image and a ``KIEDataSample`` holding
    five random ground-truth and five random predicted instances. The test
    then draws the sample under several flag combinations and asserts the
    shape of each image read back from disk.
    """

    def setUp(self):
        """Create the random image and the KIE data sample used by the test."""
        h, w = 12, 10
        self.image = np.random.randint(0, 256, size=(h, w, 3)).astype('uint8')

        # One random boolean 5x5 matrix, shared by the gt and pred instances.
        edge_labels = torch.rand((5, 5)) > 0.5

        # gt_instances
        data_sample = KIEDataSample()
        gt_instances_data = dict(
            bboxes=self._rand_bboxes(5, h, w),
            polygons=self._rand_polys(5, h, w),
            labels=torch.zeros(5, ),
            texts=['text1', 'text2', 'text3', 'text4', 'text5'],
            edge_labels=edge_labels)
        gt_instances = InstanceData(**gt_instances_data)
        data_sample.gt_instances = gt_instances

        # pred_instances (no polygons, but with per-instance scores)
        pred_instances_data = dict(
            bboxes=self._rand_bboxes(5, h, w),
            labels=torch.zeros(5, ),
            scores=torch.rand((5, )),
            texts=['text1', 'text2', 'text3', 'text4', 'text5'],
            edge_labels=edge_labels)
        pred_instances = InstanceData(**pred_instances_data)
        data_sample.pred_instances = pred_instances

        # The sample is stored after ``.numpy()`` conversion.
        data_sample = data_sample.numpy()
        self.data_sample = data_sample

    @staticmethod
    def _rand_bboxes(num_boxes, h, w):
        """Generate random boxes clamped to an ``h`` x ``w`` image.

        Declared as a static method so that ``self._rand_bboxes(n, h, w)``
        binds ``n`` to ``num_boxes``; as a plain function without ``self``
        that call raised ``TypeError`` and ``setUp`` could never finish.

        Args:
            num_boxes (int): Number of boxes to generate.
            h (int): Image height; y coordinates are clamped to ``[0, h]``.
            w (int): Image width; x coordinates are clamped to ``[0, w]``.

        Returns:
            torch.Tensor: Tensor of shape ``(num_boxes, 4)`` whose columns
            are ``(tl_x, tl_y, br_x, br_y)``.
        """
        # Each row of the random matrix is (center_x, center_y, width, height)
        # as fractions of the image size; transposing yields four 1-D tensors.
        cx, cy, bw, bh = torch.rand(num_boxes, 4).T

        tl_x = ((cx * w) - (w * bw / 2)).clamp(0, w).unsqueeze(0)
        tl_y = ((cy * h) - (h * bh / 2)).clamp(0, h).unsqueeze(0)
        br_x = ((cx * w) + (w * bw / 2)).clamp(0, w).unsqueeze(0)
        br_y = ((cy * h) + (h * bh / 2)).clamp(0, h).unsqueeze(0)

        # Stack the four (1, N) rows and transpose to (N, 4).
        bboxes = torch.cat([tl_x, tl_y, br_x, br_y], dim=0).T
        return bboxes

    def _rand_polys(self, num_bboxes, h, w):
        """Generate random polygons by converting random boxes.

        Args:
            num_bboxes (int): Number of polygons to generate.
            h (int): Image height passed on to ``_rand_bboxes``.
            w (int): Image width passed on to ``_rand_bboxes``.

        Returns:
            list: One ``bbox2poly`` result per generated box.
        """
        bboxes = self._rand_bboxes(num_bboxes, h, w)
        bboxes = bboxes.tolist()
        polys = [bbox2poly(bbox) for bbox in bboxes]
        return polys

    @staticmethod
    def _category_meta():
        """Return a fresh ``dataset_meta`` dict with the four test categories.

        A new dict is built on every call so the two visualizers created in
        ``test_add_datasample`` never share one mutable object.
        """
        return dict(category=[
            dict(id=0, name='bg'),
            dict(id=1, name='key'),
            dict(id=2, name='value'),
            dict(id=3, name='other')
        ])

    def test_add_datasample(self):
        """Draw the sample with various flags and check the saved shapes."""
        image = self.image
        h, w, c = image.shape

        visualizer = KIELocalVisualizer(is_openset=True)
        visualizer.dataset_meta = self._category_meta()
        # Smoke call without an output file.
        visualizer.add_datasample('image', image, self.data_sample)

        with tempfile.TemporaryDirectory() as tmp_dir:
            # test out
            out_file = osp.join(tmp_dir, 'out_file.jpg')

            # Neither gt nor pred drawn: output keeps the input shape.
            visualizer.add_datasample(
                'image',
                image,
                self.data_sample,
                out_file=out_file,
                draw_gt=False,
                draw_pred=False)
            self._assert_image_and_shape(out_file, (h, w, c))

            # Defaults for draw_gt/draw_pred: twice as tall, four times as
            # wide.
            visualizer.add_datasample(
                'image', image, self.data_sample, out_file=out_file)
            self._assert_image_and_shape(out_file, (h * 2, w * 4, c))

            # Only one of gt/pred drawn: input height, four times as wide.
            visualizer.add_datasample(
                'image',
                image,
                self.data_sample,
                draw_gt=False,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h, w * 4, c))

            visualizer.add_datasample(
                'image',
                image,
                self.data_sample,
                draw_pred=False,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h, w * 4, c))

            # With is_openset=False the same gt-only call is expected to
            # produce an image three times as wide. This must stay inside the
            # ``with`` block because it writes into the temporary directory.
            visualizer = KIELocalVisualizer(is_openset=False)
            visualizer.dataset_meta = self._category_meta()
            visualizer.add_datasample(
                'image',
                image,
                self.data_sample,
                draw_pred=False,
                out_file=out_file)
            self._assert_image_and_shape(out_file, (h, w * 3, c))

    def _assert_image_and_shape(self, out_file, out_shape):
        """Assert that ``out_file`` exists and decodes to ``out_shape``.

        Args:
            out_file (str): Path of the image the visualizer should have
                written.
            out_shape (tuple): Expected ``(height, width, channels)``.
        """
        self.assertTrue(osp.exists(out_file))
        drawn_img = cv2.imread(out_file)
        # cv2.imread signals a decode failure by returning None rather than
        # raising; fail here instead of with an AttributeError on ``.shape``.
        self.assertIsNotNone(drawn_img, f'cv2 could not read {out_file}')
        # assertEqual reports both shapes on failure, unlike assertTrue(a == b).
        self.assertEqual(drawn_img.shape, out_shape)