# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
import torch
from mmengine.model import BaseModule
from torch import Tensor, nn
from torch.nn import functional as F
from mmocr.models.common.dictionary import Dictionary
from mmocr.registry import MODELS, TASK_UTILS
from mmocr.structures import KIEDataSample
@MODELS.register_module()
class SDMGRHead(BaseModule):
"""SDMGR Head.
Args:
dictionary (dict or :obj:`Dictionary`): The config for `Dictionary` or
the instance of `Dictionary`.
num_classes (int): Number of class labels. Defaults to 26.
visual_dim (int): Dimension of visual features :math:`E`. Defaults to
64.
fusion_dim (int): Dimension of fusion layer. Defaults to 1024.
node_input (int): Dimension of raw node embedding. Defaults to 32.
node_embed (int): Dimension of node embedding. Defaults to 256.
edge_input (int): Dimension of raw edge embedding. Defaults to 5.
edge_embed (int): Dimension of edge embedding. Defaults to 256.
num_gnn (int): Number of GNN layers. Defaults to 2.
bidirectional (bool): Whether to use bidirectional RNN to embed nodes.
Defaults to False.
        relation_norm (float): Normalization factor applied to the pairwise
            box offsets in ``compute_relations``. Defaults to 10.
module_loss (dict): Module Loss config. Defaults to
``dict(type='SDMGRModuleLoss')``.
postprocessor (dict): Postprocessor config. Defaults to
``dict(type='SDMGRPostProcessor')``.
init_cfg (dict or list[dict], optional): Initialization configs.
"""
def __init__(
self,
dictionary: Union[Dictionary, Dict],
num_classes: int = 26,
visual_dim: int = 64,
fusion_dim: int = 1024,
node_input: int = 32,
node_embed: int = 256,
edge_input: int = 5,
edge_embed: int = 256,
num_gnn: int = 2,
bidirectional: bool = False,
relation_norm: float = 10.,
module_loss: Dict = dict(type='SDMGRModuleLoss'),
postprocessor: Dict = dict(type='SDMGRPostProcessor'),
init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
type='Normal', override=dict(name='edge_embed'), mean=0, std=0.01)
) -> None:
super().__init__(init_cfg=init_cfg)
assert isinstance(dictionary, (dict, Dictionary))
if isinstance(dictionary, dict):
self.dictionary = TASK_UTILS.build(dictionary)
elif isinstance(dictionary, Dictionary):
self.dictionary = dictionary
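        # Bilinear fusion of per-box visual features with textual node
        # embeddings; ``fusion_dim`` is the intermediate multimodal dimension.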
self.fusion = FusionBlock([visual_dim, node_embed], node_embed,
fusion_dim)
self.node_embed = nn.Embedding(self.dictionary.num_classes, node_input,
self.dictionary.padding_idx)
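        # A bidirectional LSTM concatenates both directions, so the hidden
        # size is halved to keep the output dimension at ``node_embed``.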
hidden = node_embed // 2 if bidirectional else node_embed
self.rnn = nn.LSTM(
input_size=node_input,
hidden_size=hidden,
num_layers=1,
batch_first=True,
bidirectional=bidirectional)
self.edge_embed = nn.Linear(edge_input, edge_embed)
self.gnn_layers = nn.ModuleList(
[GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
self.node_cls = nn.Linear(node_embed, num_classes)
self.edge_cls = nn.Linear(edge_embed, 2)
self.module_loss = MODELS.build(module_loss)
self.postprocessor = MODELS.build(postprocessor)
self.relation_norm = relation_norm
def loss(self, inputs: Tensor, data_samples: List[KIEDataSample]) -> Dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
inputs (torch.Tensor): Shape :math:`(N, E)`.
data_samples (List[KIEDataSample]): List of data samples.
Returns:
            dict[str, Tensor]: A dictionary of loss components.
"""
preds = self.forward(inputs, data_samples)
return self.module_loss(preds, data_samples)
def predict(self, inputs: Tensor,
data_samples: List[KIEDataSample]) -> List[KIEDataSample]:
"""Predict results from a batch of inputs and data samples with post-
processing.
Args:
inputs (torch.Tensor): Shape :math:`(N, E)`.
data_samples (List[KIEDataSample]): List of data samples.
Returns:
List[KIEDataSample]: A list of datasamples of prediction results.
Results are stored in ``pred_instances.labels``,
``pred_instances.scores``, ``pred_instances.edge_labels`` and
``pred_instances.edge_scores``.
            - labels (Tensor): An integer tensor of shape (N, ) indicating the
              label of each bbox.
- scores (Tensor): A float tensor of shape (N, ), indicating the
confidence scores for node label predictions.
- edge_labels (Tensor): An integer tensor of shape (N, N)
indicating the connection between nodes. Options are 0, 1.
- edge_scores (Tensor): A float tensor of shape (N, ), indicating
the confidence scores for edge predictions.
"""
preds = self.forward(inputs, data_samples)
return self.postprocessor(preds, data_samples)
def forward(self, inputs: Tensor,
data_samples: List[KIEDataSample]) -> Tuple[Tensor, Tensor]:
"""
Args:
inputs (torch.Tensor): Shape :math:`(N, E)`.
data_samples (List[KIEDataSample]): List of data samples.
Returns:
tuple(Tensor, Tensor):
- node_cls (Tensor): Raw logits scores for nodes. Shape
:math:`(N, C_{l})` where :math:`C_{l}` is number of classes.
- edge_cls (Tensor): Raw logits scores for edges. Shape
:math:`(N * N, 2)`.
"""
device = self.node_embed.weight.device
node_nums, char_nums, all_nodes = self.convert_texts(data_samples)
embed_nodes = self.node_embed(all_nodes.to(device).long())
rnn_nodes, _ = self.rnn(embed_nodes)
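        # Use the RNN output at each text's last non-padding character as its
        # node embedding; empty texts keep an all-zero embedding.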
nodes = rnn_nodes.new_zeros(*rnn_nodes.shape[::2])
all_nums = torch.cat(char_nums).to(device)
valid = all_nums > 0
nodes[valid] = rnn_nodes[valid].gather(
1, (all_nums[valid] - 1).unsqueeze(-1).unsqueeze(-1).expand(
-1, -1, rnn_nodes.size(-1))).squeeze(1)
if inputs is not None:
nodes = self.fusion([inputs, nodes])
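        # Pairwise spatial relations, flattened per sample to shape
        # (sum_i N_i^2, 5), then embedded and L2-normalized.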
relations = self.compute_relations(data_samples)
all_edges = torch.cat(
[relation.view(-1, relation.size(-1)) for relation in relations],
dim=0)
embed_edges = self.edge_embed(all_edges.float())
embed_edges = F.normalize(embed_edges)
for gnn_layer in self.gnn_layers:
nodes, embed_edges = gnn_layer(nodes, embed_edges, node_nums)
node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(embed_edges)
return node_cls, edge_cls
    def convert_texts(
        self, data_samples: List[KIEDataSample]
    ) -> Tuple[List[int], List[Tensor], Tensor]:
"""Extract texts in datasamples and pack them into a batch.
Args:
data_samples (List[KIEDataSample]): List of data samples.
Returns:
tuple(List[int], List[Tensor], Tensor):
- node_nums (List[int]): A list of node numbers for each
sample.
- char_nums (List[Tensor]): A list of character numbers for each
sample.
            - nodes (Tensor): A tensor of shape :math:`(N, C)` where
              :math:`C` is the maximum text length in the batch.
"""
node_nums, char_nums = [], []
max_len = -1
text_idxs = []
for data_sample in data_samples:
node_nums.append(len(data_sample.gt_instances.texts))
for text in data_sample.gt_instances.texts:
text_idxs.append(self.dictionary.str2idx(text))
max_len = max(max_len, len(text))
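        # Pack all texts into one right-padded index tensor of shape
        # (total_nodes, max_len), filled with the dictionary's padding index.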
nodes = torch.zeros((sum(node_nums), max_len),
dtype=torch.long) + self.dictionary.padding_idx
for i, text_idx in enumerate(text_idxs):
nodes[i, :len(text_idx)] = torch.LongTensor(text_idx)
char_nums = (nodes != self.dictionary.padding_idx).sum(-1).split(
node_nums, dim=0)
return node_nums, char_nums, nodes
    def compute_relations(
            self, data_samples: List[KIEDataSample]) -> List[Tensor]:
        """Compute the spatial relations between every two boxes within each
        data sample, and return them as a list of per-sample relation
        tensors."""
relations = []
for data_sample in data_samples:
bboxes = data_sample.gt_instances.bboxes
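            # relation[i, j] = [(x_j - x_i) / norm, (y_j - y_i) / norm,
            #                   w_i / h_i, h_j / h_i, w_j / h_i]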
x1, y1 = bboxes[:, 0:1], bboxes[:, 1:2]
x2, y2 = bboxes[:, 2:3], bboxes[:, 3:4]
            w = torch.clamp(x2 - x1 + 1, min=1)
            h = torch.clamp(y2 - y1 + 1, min=1)
dx = (x1.t() - x1) / self.relation_norm
dy = (y1.t() - y1) / self.relation_norm
xhh, xwh = h.T / h, w.T / h
whs = w / h + torch.zeros_like(xhh)
relation = torch.stack([dx, dy, whs, xhh, xwh], -1).float()
relations.append(relation)
return relations
class GNNLayer(nn.Module):
"""GNN layer for SDMGR.
Args:
node_dim (int): Dimension of node embedding. Defaults to 256.
edge_dim (int): Dimension of edge embedding. Defaults to 256.
"""
def __init__(self, node_dim: int = 256, edge_dim: int = 256) -> None:
super().__init__()
self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
self.coef_fc = nn.Linear(node_dim, 1)
self.out_fc = nn.Linear(node_dim, node_dim)
self.relu = nn.ReLU()
def forward(self, nodes: Tensor, edges: Tensor,
nums: List[int]) -> Tuple[Tensor, Tensor]:
"""Forward function.
Args:
nodes (Tensor): Concatenated node embeddings.
edges (Tensor): Concatenated edge embeddings.
nums (List[int]): List of number of nodes in each batch.
Returns:
tuple(Tensor, Tensor):
- nodes (Tensor): New node embeddings.
- edges (Tensor): New edge embeddings.
"""
start, cat_nodes = 0, []
for num in nums:
sample_nodes = nodes[start:start + num]
cat_nodes.append(
torch.cat([
sample_nodes.unsqueeze(1).expand(-1, num, -1),
sample_nodes.unsqueeze(0).expand(num, -1, -1)
], -1).view(num**2, -1))
start += num
cat_nodes = torch.cat([torch.cat(cat_nodes), edges], -1)
cat_nodes = self.relu(self.in_fc(cat_nodes))
coefs = self.coef_fc(cat_nodes)
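        # Attention-weighted aggregation: softmax over the sender dimension,
        # with self-pairs masked out by a large negative value on the
        # diagonal.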
start, residuals = 0, []
for num in nums:
residual = F.softmax(
-torch.eye(num).to(coefs.device).unsqueeze(-1) * 1e9 +
coefs[start:start + num**2].view(num, num, -1), 1)
residuals.append(
(residual *
cat_nodes[start:start + num**2].view(num, num, -1)).sum(1))
start += num**2
nodes += self.relu(self.out_fc(torch.cat(residuals)))
return nodes, cat_nodes
class FusionBlock(nn.Module):
"""Fusion block of SDMGR.
Args:
input_dims (tuple(int, int)): Visual dimension and node embedding
dimension.
output_dim (int): Output dimension.
        mm_dim (int): Dimension of the intermediate multimodal space.
            Defaults to 1600.
chunks (int): Number of chunks. Defaults to 20.
rank (int): Rank number. Defaults to 15.
        shared (bool): Whether to share the projection layer between visual
            and node embedding features. Defaults to False.
        dropout_input (float): Dropout rate after the first projection layer.
            Defaults to 0.
        dropout_pre_lin (float): Dropout rate before the final projection
            layer. Defaults to 0.
        dropout_output (float): Dropout rate after the final projection
            layer. Defaults to 0.
pos_norm (str): The normalization position. Options are 'before_cat'
and 'after_cat'. Defaults to 'before_cat'.
"""
def __init__(self,
input_dims: Tuple[int, int],
output_dim: int,
mm_dim: int = 1600,
chunks: int = 20,
rank: int = 15,
shared: bool = False,
dropout_input: float = 0.,
dropout_pre_lin: float = 0.,
dropout_output: float = 0.,
pos_norm: str = 'before_cat') -> None:
super().__init__()
self.rank = rank
self.dropout_input = dropout_input
self.dropout_pre_lin = dropout_pre_lin
self.dropout_output = dropout_output
        assert pos_norm in ['before_cat', 'after_cat']
self.pos_norm = pos_norm
# Modules
self.linear0 = nn.Linear(input_dims[0], mm_dim)
self.linear1 = (
self.linear0 if shared else nn.Linear(input_dims[1], mm_dim))
self.merge_linears0 = nn.ModuleList()
self.merge_linears1 = nn.ModuleList()
self.chunks = self.chunk_sizes(mm_dim, chunks)
for size in self.chunks:
ml0 = nn.Linear(size, size * rank)
self.merge_linears0.append(ml0)
ml1 = ml0 if shared else nn.Linear(size, size * rank)
self.merge_linears1.append(ml1)
self.linear_out = nn.Linear(mm_dim, output_dim)
def forward(self, x: Tensor) -> Tensor:
"""Forward function."""
x0 = self.linear0(x[0])
x1 = self.linear1(x[1])
bs = x1.size(0)
if self.dropout_input > 0:
x0 = F.dropout(x0, p=self.dropout_input, training=self.training)
x1 = F.dropout(x1, p=self.dropout_input, training=self.training)
x0_chunks = torch.split(x0, self.chunks, -1)
x1_chunks = torch.split(x1, self.chunks, -1)
zs = []
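        # Low-rank bilinear fusion per chunk: both projections are expanded
        # by a factor of ``rank``, multiplied elementwise, then summed over
        # the rank dimension.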
for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks,
self.merge_linears0,
self.merge_linears1):
m = m0(x0_c) * m1(x1_c) # bs x split_size*rank
m = m.view(bs, self.rank, -1)
z = torch.sum(m, 1)
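            # Signed square root (power normalization) and L2 normalization,
            # applied per chunk when ``pos_norm`` is 'before_cat'.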
if self.pos_norm == 'before_cat':
z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z))
z = F.normalize(z)
zs.append(z)
z = torch.cat(zs, 1)
if self.pos_norm == 'after_cat':
z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z))
z = F.normalize(z)
if self.dropout_pre_lin > 0:
z = F.dropout(z, p=self.dropout_pre_lin, training=self.training)
z = self.linear_out(z)
if self.dropout_output > 0:
z = F.dropout(z, p=self.dropout_output, training=self.training)
return z
@staticmethod
def chunk_sizes(dim: int, chunks: int) -> List[int]:
"""Compute chunk sizes."""
split_size = (dim + chunks - 1) // chunks
sizes_list = [split_size] * chunks
sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim)
return sizes_list