Spaces:
Running
Running
File size: 2,629 Bytes
9bf4bd7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple
import torch
from mmdet.models.losses import accuracy
from torch import Tensor, nn
from mmocr.registry import MODELS
from mmocr.structures import KIEDataSample
@MODELS.register_module()
class SDMGRModuleLoss(nn.Module):
    """The implementation of the loss of key information extraction proposed
    in the paper: `Spatial Dual-Modality Graph Reasoning for Key Information
    Extraction <https://arxiv.org/abs/2103.14470>`_.

    The node and edge classification losses are cross-entropy losses, scaled
    by ``weight_node`` and ``weight_edge`` respectively. Node and edge
    accuracies are reported alongside them.

    Args:
        weight_node (float): Weight of node loss. Defaults to 1.0.
        weight_edge (float): Weight of edge loss. Defaults to 1.0.
        ignore_idx (int): Node label to ignore. Defaults to -100.
        ignore_edge_idx (int): Edge label to ignore. Defaults to -1.
    """

    def __init__(self,
                 weight_node: float = 1.0,
                 weight_edge: float = 1.0,
                 ignore_idx: int = -100,
                 ignore_edge_idx: int = -1) -> None:
        super().__init__()
        # TODO: Use MODELS.build after DRRG loss has been merged
        self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore_idx)
        # The same ignore label is reused in ``forward`` to mask the edge
        # accuracy, so keep it in one place.
        self.loss_edge = nn.CrossEntropyLoss(ignore_index=ignore_edge_idx)
        self.weight_node = weight_node
        self.weight_edge = weight_edge
        self.ignore_idx = ignore_idx
        self.ignore_edge_idx = ignore_edge_idx

    def forward(self, preds: Tuple[Tensor, Tensor],
                data_samples: List[KIEDataSample]) -> Dict:
        """Forward function.

        Args:
            preds (tuple(Tensor, Tensor)): ``(node_preds, edge_preds)``,
                the node and edge classification scores for the whole batch.
                - The first dimension of ``node_preds`` must equal the number
                  of node labels concatenated over ``data_samples``.
                - The first dimension of ``edge_preds`` must equal the number
                  of flattened edge labels concatenated over
                  ``data_samples``.
                - Presumably the second dimension is the class count of each
                  task (TODO confirm against the SDMGR head).
            data_samples (list[KIEDataSample]): A list of datasamples
                containing ``gt_instances.labels`` and
                ``gt_instances.edge_labels``.

        Returns:
            dict(str, Tensor): Loss dict, containing ``loss_node``,
            ``loss_edge``, ``acc_node`` and ``acc_edge``.
        """
        node_preds, edge_preds = preds

        # Gather ground truths for the whole batch in sample order. Edge
        # labels are flattened per sample before concatenation.
        node_gts = torch.cat([
            data_sample.gt_instances.labels for data_sample in data_samples
        ]).long()
        edge_gts = torch.cat([
            data_sample.gt_instances.edge_labels.reshape(-1)
            for data_sample in data_samples
        ]).long()

        # The losses skip ignored labels through ``ignore_index``. Accuracy
        # has no such option, so restrict it to the labelled rows explicitly.
        node_valids = node_gts != self.ignore_idx
        edge_valids = edge_gts != self.ignore_edge_idx

        return dict(
            loss_node=self.weight_node * self.loss_node(node_preds, node_gts),
            loss_edge=self.weight_edge * self.loss_edge(edge_preds, edge_gts),
            acc_node=accuracy(node_preds[node_valids], node_gts[node_valids]),
            acc_edge=accuracy(edge_preds[edge_valids], edge_gts[edge_valids]))
|