# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from os import path as osp

import numpy as np
import torch
from mmdet.datasets.builder import DATASETS

from mmocr.core import compute_f1_score
from mmocr.datasets.base_dataset import BaseDataset
from mmocr.datasets.pipelines import sort_vertex8
from mmocr.utils import is_type_list, list_from_file


@DATASETS.register_module()
class KIEDataset(BaseDataset):
    """Dataset for key information extraction (KIE).

    Args:
        ann_file (str): Annotation file path.
        pipeline (list[dict]): Processing pipeline.
        loader (dict): Dictionary to construct loader
            to load annotation infos.
        img_prefix (str, optional): Image prefix to generate full
            image path.
        test_mode (bool, optional): If True, try...except will
            be turned off in __getitem__.
        dict_file (str): Character dict file path.
        norm (float): Scale factor that divides the relative box offsets
            computed in :meth:`compute_relation`.
        directed (bool): Whether to build directed edge labels.
    """

    def __init__(self,
                 ann_file=None,
                 loader=None,
                 dict_file=None,
                 img_prefix='',
                 pipeline=None,
                 norm=10.,
                 directed=False,
                 test_mode=True,
                 **kwargs):
        if ann_file is None and loader is None:
            warnings.warn(
                'KIEDataset is only initialized as a downstream demo task '
                'of text detection and recognition '
                'without an annotation file.', UserWarning)
        else:
            super().__init__(
                ann_file,
                loader,
                pipeline,
                img_prefix=img_prefix,
                test_mode=test_mode)
        assert osp.exists(dict_file)

        self.norm = norm
        self.directed = directed

        self.dict = {
            '': 0,
            **{
                line.rstrip('\r\n'): ind
                for ind, line in enumerate(list_from_file(dict_file), 1)
            }
        }
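        # Example of the resulting mapping (illustrative, not from the
        # original file): for a dict_file with lines "a", "b" and "c",
        # self.dict is {'': 0, 'a': 1, 'b': 2, 'c': 3}. Characters absent
        # from the dict are dropped when texts are indexed in
        # _parse_anno_info().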

    def pre_pipeline(self, results):
        results['img_prefix'] = self.img_prefix
        results['bbox_fields'] = []
        results['ori_texts'] = results['ann_info']['ori_texts']
        results['filename'] = osp.join(self.img_prefix,
                                       results['img_info']['filename'])
        results['ori_filename'] = results['img_info']['filename']
        # A dummy image array as placeholder data.
        results['img'] = np.zeros((0, 0, 0), dtype=np.uint8)

    def _parse_anno_info(self, annotations):
        """Parse annotations of boxes, texts and labels for one image.

        Args:
            annotations (list[dict]): Annotations of one image, where
                each dict corresponds to one text box.

        Returns:
            dict: A dict containing the following keys:

                - bboxes (np.ndarray): Bboxes in one image with shape:
                  box_num * 4. They are sorted clockwise when loading.
                - relations (np.ndarray): Relations between bboxes with shape:
                  box_num * box_num * D.
                - texts (np.ndarray): Text indices with shape:
                  box_num * text_max_len.
                - labels (np.ndarray): Box labels with shape:
                  box_num * (box_num + 1).
        """
        assert is_type_list(annotations, dict)
        assert len(annotations) > 0, 'Please remove data with empty annotation'
        assert 'box' in annotations[0]
        assert 'text' in annotations[0]

        boxes, texts, text_inds, labels, edges = [], [], [], [], []
        for ann in annotations:
            box = ann['box']
            sorted_box = sort_vertex8(box[:8])
            boxes.append(sorted_box)
            text = ann['text']
            texts.append(text)
            # Keep only characters that appear in the dict; others are dropped.
            text_ind = [self.dict[c] for c in text if c in self.dict]
            text_inds.append(text_ind)
            labels.append(ann.get('label', 0))
            edges.append(ann.get('edge', 0))

        ann_infos = dict(
            boxes=boxes,
            texts=texts,
            text_inds=text_inds,
            edges=edges,
            labels=labels)

        return self.list_to_numpy(ann_infos)

    def prepare_train_img(self, index):
        """Get training data and annotations from pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """
        img_ann_info = self.data_infos[index]
        img_info = {
            'filename': img_ann_info['file_name'],
            'height': img_ann_info['height'],
            'width': img_ann_info['width']
        }
        ann_info = self._parse_anno_info(img_ann_info['annotations'])
        results = dict(img_info=img_info, ann_info=ann_info)

        self.pre_pipeline(results)

        return self.pipeline(results)

    def evaluate(self,
                 results,
                 metric='macro_f1',
                 metric_options=dict(macro_f1=dict(ignores=[])),
                 **kwargs):
        """Evaluate node classification results with the macro F1 metric.

        ``metric_options['macro_f1']['ignores']`` lists the label indices
        excluded from the averaged score.
        """
        # allow some kwargs to pass through
        assert set(kwargs).issubset(['logger'])

        # Protect ``metric_options`` since it uses mutable value as default
        metric_options = copy.deepcopy(metric_options)

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['macro_f1']
        for m in metrics:
            if m not in allowed_metrics:
                raise KeyError(f'metric {m} is not supported')

        return self.compute_macro_f1(results, **metric_options['macro_f1'])

    def compute_macro_f1(self, results, ignores=[]):
        """Compute macro F1 of node labels over all images, skipping any
        class index listed in ``ignores``."""
        node_preds = []
        node_gts = []
        for idx, result in enumerate(results):
            node_preds.append(result['nodes'].cpu())
            box_ann_infos = self.data_infos[idx]['annotations']
            node_gt = [box_ann_info['label'] for box_ann_info in box_ann_infos]
            node_gts.append(torch.Tensor(node_gt))

        node_preds = torch.cat(node_preds)
        node_gts = torch.cat(node_gts).int()

        node_f1s = compute_f1_score(node_preds, node_gts, ignores)

        return {
            'macro_f1': node_f1s.mean(),
        }

    def list_to_numpy(self, ann_infos):
        """Convert bboxes, relations, texts and labels to ndarray."""
        boxes, text_inds = ann_infos['boxes'], ann_infos['text_inds']
        texts = ann_infos['texts']
        boxes = np.array(boxes, np.int32)
        relations, bboxes = self.compute_relation(boxes)

        labels = ann_infos.get('labels', None)
        if labels is not None:
            labels = np.array(labels, np.int32)
            edges = ann_infos.get('edges', None)
            if edges is not None:
                labels = labels[:, None]
                edges = np.array(edges)
                # Boxes sharing the same edge id are connected.
                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
                if self.directed:
                    # Note: `&` binds tighter than `==`, so this evaluates as
                    # (edges & labels) == 1.
                    edges = (edges & labels == 1).astype(np.int32)
                # The diagonal (self-pairs) is set to -1.
                np.fill_diagonal(edges, -1)
                # Append the N x N edge matrix to the N x 1 node labels.
                labels = np.concatenate([labels, edges], -1)
        padded_text_inds = self.pad_text_indices(text_inds)

        return dict(
            bboxes=bboxes,
            relations=relations,
            texts=padded_text_inds,
            ori_texts=texts,
            labels=labels)

    def pad_text_indices(self, text_inds):
        """Pad text indices to the same length with -1."""
        max_len = max([len(text_ind) for text_ind in text_inds])
        padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
        for idx, text_ind in enumerate(text_inds):
            padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
        return padded_text_inds

    def compute_relation(self, boxes):
        """Compute the spatial relation feature between every two boxes."""
        # Get minimal axis-aligned bounding boxes for each of the boxes
        # yapf: disable
        bboxes = np.concatenate(
            [boxes[:, 0::2].min(axis=1, keepdims=True),
             boxes[:, 1::2].min(axis=1, keepdims=True),
             boxes[:, 0::2].max(axis=1, keepdims=True),
             boxes[:, 1::2].max(axis=1, keepdims=True)],
            axis=1).astype(np.float32)
        # yapf: enable
        x1, y1 = bboxes[:, 0:1], bboxes[:, 1:2]
        x2, y2 = bboxes[:, 2:3], bboxes[:, 3:4]
        w, h = np.maximum(x2 - x1 + 1, 1), np.maximum(y2 - y1 + 1, 1)
        # Normalized offsets between top-left corners:
        # dx[i, j] = (x1_j - x1_i) / norm.
        dx = (x1.T - x1) / self.norm
        dy = (y1.T - y1) / self.norm
        # Size ratios: xhh[i, j] = h_j / h_i, xwh[i, j] = w_j / h_i.
        xhh, xwh = h.T / h, w.T / h
        # Aspect ratio of box i, broadcast to an N x N matrix.
        whs = w / h + np.zeros_like(xhh)
        # relation[i, j] is a 5-dim feature describing box j relative to box i.
        relation = np.stack([dx, dy, whs, xhh, xwh], -1).astype(np.float32)
        return relation, bboxes
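

if __name__ == '__main__':
    # Minimal usage sketch (not part of the upstream file): build the dataset
    # in demo mode with a throwaway character dict and run the annotation
    # parsing helpers on a toy image. All values below are made up.
    import tempfile

    with tempfile.NamedTemporaryFile(
            'w', suffix='.txt', delete=False) as dict_f:
        dict_f.write('\n'.join('0123456789abcdef'))

    # With ann_file=None and loader=None the class only warns and skips the
    # BaseDataset initialization, so no annotation file is needed here.
    dataset = KIEDataset(dict_file=dict_f.name)

    toy_annotations = [
        dict(box=[0, 0, 10, 0, 10, 10, 0, 10], text='ab', label=1),
        dict(box=[20, 0, 40, 0, 40, 10, 20, 10], text='12', label=2),
    ]
    ann = dataset._parse_anno_info(toy_annotations)
    print(ann['bboxes'].shape)     # (2, 4): axis-aligned boxes
    print(ann['relations'].shape)  # (2, 2, 5): pairwise relation features
    print(ann['texts'])            # per-box char indices, padded with -1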