# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
from mmdet.core import bbox2roi
from torch import nn
from torch.nn import functional as F

from mmocr.core import imshow_edge, imshow_node
from mmocr.models.builder import DETECTORS, build_roi_extractor
from mmocr.models.common.detectors import SingleStageDetector
from mmocr.utils import list_from_file


@DETECTORS.register_module()
class SDMGR(SingleStageDetector):
    """The implementation of the paper: Spatial Dual-Modality Graph Reasoning
    for Key Information Extraction. https://arxiv.org/abs/2103.14470.

    Args:
        visual_modality (bool): Whether to use the visual modality.
        class_list (None | str): Mapping file of class index to class name.
            If None, the class index is shown in `show_result`; otherwise the
            class name is shown.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 extractor=dict(
                     type='mmdet.SingleRoIExtractor',
                     roi_layer=dict(type='RoIAlign', output_size=7),
                     featmap_strides=[1]),
                 visual_modality=False,
                 train_cfg=None,
                 test_cfg=None,
                 class_list=None,
                 init_cfg=None,
                 openset=False):
        super().__init__(
            backbone, neck, bbox_head, train_cfg, test_cfg, init_cfg=init_cfg)
        self.visual_modality = visual_modality
        if visual_modality:
            self.extractor = build_roi_extractor({
                **extractor, 'out_channels': self.backbone.base_channels
            })
            self.maxpool = nn.MaxPool2d(extractor['roi_layer']['output_size'])
        else:
            self.extractor = None
        self.class_list = class_list
        self.openset = openset

    def forward_train(self, img, img_metas, relations, texts, gt_bboxes,
                      gt_labels):
        """
        Args:
            img (tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dict where each dict
                contains: 'img_shape', 'scale_factor', 'flip', and may also
                contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'. For details on the values of these keys,
                please see :class:`mmdet.datasets.pipelines.Collect`.
            relations (list[tensor]): Relations between bboxes.
            texts (list[tensor]): Texts in bboxes.
            gt_bboxes (list[tensor]): Ground-truth bboxes for each image,
                in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[tensor]): Class indices corresponding to each box.

        Returns:
            dict[str, tensor]: A dictionary of loss components.
        """
        x = self.extract_feat(img, gt_bboxes)
        node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
        return self.bbox_head.loss(node_preds, edge_preds, gt_labels)
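
    # Rough shape sketch for a training call (assumed, not guaranteed): with a
    # batch of one image containing N annotated text boxes, `relations` holds
    # one (N, N, R) tensor of pairwise geometric relations, `texts` holds one
    # (N, max_len) tensor of padded character indices, and the returned dict
    # carries the node/edge classification losses produced by the head
    # (typically keyed 'loss_node' and 'loss_edge').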

    def forward_test(self,
                     img,
                     img_metas,
                     relations,
                     texts,
                     gt_bboxes,
                     rescale=False):
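        # Inference mirrors `forward_train` without the loss: visual RoI
        # features are extracted (or None when the visual modality is off),
        # the head predicts node and edge logits, and a single-element result
        # list is returned with softmax probabilities under 'nodes' (per-box
        # classes) and 'edges' (per-pair link scores).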
        x = self.extract_feat(img, gt_bboxes)
        node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
        return [
            dict(
                img_metas=img_metas,
                nodes=F.softmax(node_preds, -1),
                edges=F.softmax(edge_preds, -1))
        ]

    def extract_feat(self, img, gt_bboxes):
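        # With the default extractor config, RoIAlign crops each gt box from
        # the last backbone feature map into a 7x7 patch, the max-pool
        # collapses it to 1x1, and the result is flattened to a per-box
        # feature vector of size `backbone.base_channels`. Returns None when
        # the visual modality is disabled.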
        if self.visual_modality:
            x = super().extract_feat(img)[-1]
            feats = self.maxpool(self.extractor([x], bbox2roi(gt_bboxes)))
            return feats.view(feats.size(0), -1)
        return None

    def show_result(self,
                    img,
                    result,
                    boxes,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    **kwargs):
        """Draw `result` on `img`.

        Args:
            img (str or tensor): The image to be displayed.
            result (dict): The results to draw on `img`.
            boxes (list): Bboxes of `img`.
            win_name (str): The window name.
            wait_time (int): Value of waitKey param.
                Default: 0.
            show (bool): Whether to show the image.
                Default: False.
            out_file (str or None): The output filename.
                Default: None.

        Returns:
            img (tensor): The image with `result` drawn on it.
        """
        img = mmcv.imread(img)
        img = img.copy()
        idx_to_cls = {}
        if self.class_list is not None:
            for line in list_from_file(self.class_list):
                class_idx, class_label = line.strip().split()
                idx_to_cls[class_idx] = class_label

        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False

        if self.openset:
            img = imshow_edge(
                img,
                result,
                boxes,
                show=show,
                win_name=win_name,
                wait_time=wait_time,
                out_file=out_file)
        else:
            img = imshow_node(
                img,
                result,
                boxes,
                idx_to_cls=idx_to_cls,
                show=show,
                win_name=win_name,
                wait_time=wait_time,
                out_file=out_file)

        if not (show or out_file):
            warnings.warn('show==False and out_file is not specified, only '
                          'result image will be returned')
            return img

        return img
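

# Illustrative end-to-end sketch (the preprocessing that yields `img_metas`,
# `relations`, `texts` and `boxes` is assumed to come from a KIE data
# pipeline and is not shown here):
#
#   model = SDMGR(backbone=..., bbox_head=..., class_list='class_list.txt')
#   result = model.forward_test(img, img_metas, relations, texts, gt_bboxes)[0]
#   model.show_result('demo.jpg', result, boxes, out_file='vis_kie.jpg')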