Spaces:
Running
Running
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Dict, List, Optional, Sequence, Tuple, Union | |
import cv2 | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
try: | |
from lanms import merge_quadrangle_n9 as la_nms | |
except ImportError: | |
la_nms = None | |
from mmcv.ops import RoIAlignRotated | |
from mmengine.model import BaseModule | |
from numpy import ndarray | |
from torch import Tensor | |
from torch.nn import init | |
from mmocr.models.textdet.heads import BaseTextDetHead | |
from mmocr.registry import MODELS | |
from mmocr.structures import TextDetDataSample | |
from mmocr.utils import fill_hole | |
def normalize_adjacent_matrix(mat: ndarray) -> ndarray:
    """Symmetrically normalize an adjacency matrix for a GCN layer.

    Self-loops are added first, then the matrix is scaled on both sides by
    the inverse square root of the column sums. This code was partially
    adapted from https://github.com/GXYM/DRRG licensed under the MIT license.

    Args:
        mat (ndarray): A square 2-D adjacency matrix.

    Returns:
        ndarray: The normalized matrix, computed as
        ``((mat + I) @ D) .T @ D`` with ``D = diag(colsum(mat + I) ** -0.5)``.
        Columns whose degree is zero get a scale of 0 instead of inf.
    """
    assert mat.ndim == 2
    assert mat.shape[0] == mat.shape[1]

    with_self_loops = mat + np.eye(mat.shape[0])
    # Degrees are column sums; negatives are clipped to 0 so that the
    # power below yields inf, which is then zeroed out.
    degrees = np.clip(np.sum(with_self_loops, axis=0), 0, None)
    inv_sqrt_degrees = np.power(degrees, -0.5).flatten()
    inv_sqrt_degrees[np.isinf(inv_sqrt_degrees)] = 0.0
    scaling = np.diag(inv_sqrt_degrees)
    return (with_self_loops @ scaling).T @ scaling
def euclidean_distance_matrix(mat_a: ndarray, mat_b: ndarray) -> ndarray:
    """Compute pairwise Euclidean distances between two point sets.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 a.b``. Squared
    distances that come out slightly negative from round-off are clamped
    to zero before taking the square root.

    Args:
        mat_a (ndarray): Points of shape ``(m, d)``.
        mat_b (ndarray): Points of shape ``(n, d)``; ``d`` must match
            ``mat_a``.

    Returns:
        ndarray: The ``(m, n)`` distance matrix. The result is promoted to at
        least float64 by the multiplication with ``np.ones``.
    """
    assert mat_a.ndim == 2
    assert mat_b.ndim == 2
    assert mat_a.shape[1] == mat_b.shape[1]

    num_a, num_b = mat_a.shape[0], mat_b.shape[0]
    # Broadcast each squared norm to the full (m, n) grid.
    sq_norms_a = np.sum(mat_a * mat_a, axis=1).reshape(
        (num_a, 1)) * np.ones(shape=(1, num_b))
    sq_norms_b = np.ones(shape=(num_a, 1)) * np.sum(mat_b * mat_b, axis=1)
    cross_terms = 2 * (mat_a @ mat_b.T)

    squared = sq_norms_a + sq_norms_b - cross_terms
    squared[squared < 0.0] = 0.0
    return np.sqrt(squared)
def feature_embedding(input_feats: ndarray, out_feat_len: int) -> ndarray:
    """Embed node features into a sinusoidal vector of fixed length.

    The ``(N, d)`` input is tiled ``out_feat_len // d`` times. When
    ``out_feat_len`` is not a multiple of ``d``, one extra zero-padded tile
    holds the first ``out_feat_len % d`` features. Each tile is divided by
    its own wavelength, passed through sin/cos, and the tiles are laid
    side by side and truncated to ``out_feat_len`` columns. This code was
    partially adapted from https://github.com/GXYM/DRRG licensed under the
    MIT license.

    Args:
        input_feats (ndarray): The input features of shape (N, d), where N is
            the number of nodes in graph, d is the input feature vector length.
        out_feat_len (int): The length of output feature vector. Must be at
            least ``d``.

    Returns:
        ndarray: The embedded features of shape ``(N, out_feat_len)``, always
        float32. Previously the non-divisible case returned float64, which
        mixed dtypes when concatenated with float32 RoI features.
    """
    assert input_feats.ndim == 2
    assert isinstance(out_feat_len, int)
    assert out_feat_len >= input_feats.shape[1]

    num_nodes, feat_dim = input_feats.shape
    feat_repeat_times = out_feat_len // feat_dim
    residue_dim = out_feat_len % feat_dim

    repeat_feats = np.repeat(
        np.expand_dims(input_feats, axis=0), feat_repeat_times, axis=0)
    if residue_dim > 0:
        # One extra tile: the leading `residue_dim` features, zero-padded.
        num_waves = feat_repeat_times + 1
        exponent_offset = 1
        residue_feats = np.hstack([
            input_feats[:, 0:residue_dim],
            np.zeros((num_nodes, feat_dim - residue_dim))
        ])
        repeat_feats = np.concatenate(
            [repeat_feats, np.expand_dims(residue_feats, axis=0)], axis=0)
    else:
        num_waves = feat_repeat_times
        exponent_offset = 0

    # NOTE(review): by operator precedence the offset is added to the whole
    # exponent, i.e. 1000 ** ((2 * (j // 2) / R) + offset). Kept unchanged
    # so the embedding matches what existing checkpoints were trained on.
    embed_wave = np.array([
        np.power(1000, 2.0 * (j // 2) / feat_repeat_times + exponent_offset)
        for j in range(num_waves)
    ]).reshape((num_waves, 1, 1))

    embedded_feats = repeat_feats / embed_wave
    # NOTE(review): axis 1 is the node axis here, so sin/cos alternate per
    # node rather than per feature dimension. Kept for checkpoint
    # compatibility.
    embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
    embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
    # An explicit column count (instead of -1) keeps N == 0 working.
    embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
        (num_nodes, num_waves * feat_dim))[:, 0:out_feat_len]
    return embedded_feats.astype(np.float32)
class DRRGHead(BaseTextDetHead):
    """The class for DRRG head: `Deep Relational Reasoning Graph Network for
    Arbitrary Shape Text Detection <https://arxiv.org/abs/2003.07493>`_.

    Args:
        in_channels (int): The number of input channels.
        k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2.
            Defaults to (8, 4).
        num_adjacent_linkages (int): The number of linkages when constructing
            adjacent matrix. Defaults to 3.
        node_geo_feat_len (int): The length of embedded geometric feature
            vector of a component. Defaults to 120.
        pooling_scale (float): The spatial scale of rotated RoI-Align. Defaults
            to 1.0.
        pooling_output_size (tuple(int)): The output size of RRoI-Aligning.
            Defaults to (4, 3).
        nms_thr (float): The locality-aware NMS threshold of text components.
            Defaults to 0.3.
        min_width (float): The minimum width of text components. Defaults to
            8.0.
        max_width (float): The maximum width of text components. Defaults to
            24.0.
        comp_shrink_ratio (float): The shrink ratio of text components.
            Defaults to 1.03.
        comp_ratio (float): The reciprocal of aspect ratio of text components.
            Defaults to 0.4.
        comp_score_thr (float): The score threshold of text components.
            Defaults to 0.3.
        text_region_thr (float): The threshold for text region probability map.
            Defaults to 0.2.
        center_region_thr (float): The threshold for text center region
            probability map. Defaults to 0.2.
        center_region_area_thr (int): The threshold for filtering small-sized
            text center region. Defaults to 50.
        local_graph_thr (float): The threshold to filter identical local
            graphs. Defaults to 0.7.
        module_loss (dict): The config of loss that DRRGHead uses. Defaults to
            ``dict(type='DRRGModuleLoss')``.
        postprocessor (dict): Config of postprocessor for DRRG. Defaults to
            ``dict(type='DRRGPostprocessor', link_thr=0.85)``.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to ``dict(type='Normal',
            override=dict(name='out_conv'), mean=0, std=0.01)``.
    """

    def __init__(
        self,
        in_channels: int,
        k_at_hops: Tuple[int, int] = (8, 4),
        num_adjacent_linkages: int = 3,
        node_geo_feat_len: int = 120,
        pooling_scale: float = 1.0,
        pooling_output_size: Tuple[int, int] = (4, 3),
        nms_thr: float = 0.3,
        min_width: float = 8.0,
        max_width: float = 24.0,
        comp_shrink_ratio: float = 1.03,
        comp_ratio: float = 0.4,
        comp_score_thr: float = 0.3,
        text_region_thr: float = 0.2,
        center_region_thr: float = 0.2,
        center_region_area_thr: int = 50,
        local_graph_thr: float = 0.7,
        module_loss: Dict = dict(type='DRRGModuleLoss'),
        postprocessor: Dict = dict(type='DRRGPostprocessor', link_thr=0.85),
        init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
            type='Normal', override=dict(name='out_conv'), mean=0, std=0.01)
    ) -> None:
        super().__init__(
            module_loss=module_loss,
            postprocessor=postprocessor,
            init_cfg=init_cfg)

        assert isinstance(in_channels, int)
        assert isinstance(k_at_hops, tuple)
        assert isinstance(num_adjacent_linkages, int)
        assert isinstance(node_geo_feat_len, int)
        assert isinstance(pooling_scale, float)
        assert isinstance(pooling_output_size, tuple)
        assert isinstance(comp_shrink_ratio, float)
        assert isinstance(nms_thr, float)
        assert isinstance(min_width, float)
        assert isinstance(max_width, float)
        assert isinstance(comp_ratio, float)
        assert isinstance(comp_score_thr, float)
        assert isinstance(text_region_thr, float)
        assert isinstance(center_region_thr, float)
        assert isinstance(center_region_area_thr, int)
        assert isinstance(local_graph_thr, float)

        self.in_channels = in_channels
        self.out_channels = 6
        self.downsample_ratio = 1.0
        self.k_at_hops = k_at_hops
        self.num_adjacent_linkages = num_adjacent_linkages
        self.node_geo_feat_len = node_geo_feat_len
        self.pooling_scale = pooling_scale
        self.pooling_output_size = pooling_output_size
        self.comp_shrink_ratio = comp_shrink_ratio
        self.nms_thr = nms_thr
        self.min_width = min_width
        self.max_width = max_width
        self.comp_ratio = comp_ratio
        self.comp_score_thr = comp_score_thr
        self.text_region_thr = text_region_thr
        self.center_region_thr = center_region_thr
        self.center_region_area_thr = center_region_area_thr
        self.local_graph_thr = local_graph_thr

        # 1x1 conv producing the `out_channels` (6) prediction maps.
        self.out_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0)

        self.graph_train = LocalGraphs(self.k_at_hops,
                                       self.num_adjacent_linkages,
                                       self.node_geo_feat_len,
                                       self.pooling_scale,
                                       self.pooling_output_size,
                                       self.local_graph_thr)

        self.graph_test = ProposalLocalGraphs(
            self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len,
            self.pooling_scale, self.pooling_output_size, self.nms_thr,
            self.min_width, self.max_width, self.comp_shrink_ratio,
            self.comp_ratio, self.comp_score_thr, self.text_region_thr,
            self.center_region_thr, self.center_region_area_thr)

        # Node feature = flattened RoI-pooled content of the concatenated
        # (inputs, pred_maps) features + embedded geometric feature.
        pool_w, pool_h = self.pooling_output_size
        node_feat_len = (pool_w * pool_h) * (
            self.in_channels + self.out_channels) + self.node_geo_feat_len
        self.gcn = GCN(node_feat_len)

    def loss(self, inputs: torch.Tensor, data_samples: List[TextDetDataSample]
             ) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute the training loss.

        Args:
            inputs (Tensor): Shape of :math:`(N, C, H, W)`.
            data_samples (List[TextDetDataSample]): List of data samples.

        Returns:
            The value returned by ``self.module_loss`` when called with
            ``(pred_maps, gcn_pred, gt_labels)`` and ``data_samples``, where:

            - pred_maps (Tensor): Prediction map with shape
              :math:`(N, 6, H, W)`.
            - gcn_pred (Tensor): Prediction from the GCN module.
            - gt_labels (Tensor): Ground-truth linkage labels built by
              ``self.graph_train``.

            NOTE(review): this is presumably a dict of loss tensors. Confirm
            against the module loss implementation; the return annotation is
            kept unchanged for compatibility.
        """
        targets = self.module_loss.get_targets(data_samples)
        # The last target holds the per-image component attributes.
        gt_comp_attribs = targets[-1]

        pred_maps = self.out_conv(inputs)
        feat_maps = torch.cat([inputs, pred_maps], dim=1)
        node_feats, adjacent_matrices, knn_inds, gt_labels = self.graph_train(
            feat_maps, np.stack(gt_comp_attribs))

        gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inds)

        return self.module_loss((pred_maps, gcn_pred, gt_labels), data_samples)

    def forward(
        self,
        inputs: Tensor,
        data_samples: Optional[List[TextDetDataSample]] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Run DRRG head in prediction mode, and return the raw tensors only.

        Args:
            inputs (Tensor): Shape of :math:`(1, C, H, W)`.
            data_samples (list[TextDetDataSample], optional): A list of data
                samples. Not used by this method. Defaults to None.

        Returns:
            tuple: Returns (edge, score, text_comps), or
            ``(None, None, None)`` when ``self.graph_test`` reports that no
            graph could be built.

            - edge (ndarray): The int64 edge array of shape
              :math:`(N_{edges}, 2)` where each row is a pair of text
              component indices that makes up an edge in graph.
            - score (ndarray): The float64 score array of shape
              :math:`(N_{edges},)`. Each entry is the softmax probability of
              class 1 (the positive-link class) for the edge above.
            - text_comps (ndarray): The text components of shape
              :math:`(M, 9)` where each row corresponds to one box and
              its score: (x1, y1, x2, y2, x3, y3, x4, y4, score).
        """
        pred_maps = self.out_conv(inputs)
        inputs = torch.cat([inputs, pred_maps], dim=1)

        none_flag, graph_data = self.graph_test(pred_maps, inputs)

        (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
         pivot_local_graphs, text_comps) = graph_data

        if none_flag:
            return None, None, None

        gcn_pred = self.gcn(local_graphs_node_feat, adjacent_matrices,
                            pivots_knn_inds)
        pred_labels = F.softmax(gcn_pred, dim=1)

        # Transfer each tensor to the host once. Calling `.item()` per edge
        # forces a device synchronisation for every single edge.
        num_pivots, num_knn = pivots_knn_inds.shape[:2]
        knn_inds = pivots_knn_inds.long().cpu().numpy()
        # `reshape` rather than `squeeze`: squeezing would drop the pivot
        # axis when there is exactly one pivot.
        local_graphs = pivot_local_graphs.long().cpu().numpy().reshape(
            num_pivots, -1)

        # GCN output row `p * num_knn + k` scores the edge from pivot p to
        # its k-th nearest neighbor, so edges are emitted pivot-major.
        pivots = np.repeat(local_graphs[:, :1], num_knn, axis=1)
        neighbors = np.take_along_axis(local_graphs, knn_inds, axis=1)
        edges = np.stack([pivots, neighbors], axis=-1).reshape(-1, 2)
        scores = pred_labels[:num_pivots * num_knn, 1].detach().cpu().numpy(
        ).astype(np.float64)

        return edges, scores, text_comps
class LocalGraphs:
    """Generate local graphs for GCN to classify the neighbors of a pivot for
    `DRRG: Deep Relational Reasoning Graph Network for Arbitrary Shape Text
    Detection <https://arxiv.org/abs/2003.07493>`_.

    This code was partially adapted from
    https://github.com/GXYM/DRRG licensed under the MIT license.

    Args:
        k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2.
        num_adjacent_linkages (int): The number of linkages when constructing
            adjacent matrix.
        node_geo_feat_len (int): The length of embedded geometric feature
            vector of a text component.
        pooling_scale (float): The spatial scale of rotated RoI-Align.
        pooling_output_size (tuple(int)): The output size of rotated RoI-Align.
        local_graph_thr(float): The threshold for filtering out identical local
            graphs.
    """

    def __init__(self, k_at_hops: Tuple[int, int], num_adjacent_linkages: int,
                 node_geo_feat_len: int, pooling_scale: float,
                 pooling_output_size: Sequence[int],
                 local_graph_thr: float) -> None:
        assert len(k_at_hops) == 2
        assert all(isinstance(n, int) for n in k_at_hops)
        assert isinstance(num_adjacent_linkages, int)
        assert isinstance(node_geo_feat_len, int)
        assert isinstance(pooling_scale, float)
        assert all(isinstance(n, int) for n in pooling_output_size)
        assert isinstance(local_graph_thr, float)

        self.k_at_hops = k_at_hops
        self.num_adjacent_linkages = num_adjacent_linkages
        self.node_geo_feat_dim = node_geo_feat_len
        self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
        self.local_graph_thr = local_graph_thr

    def generate_local_graphs(self, sorted_dist_inds: ndarray,
                              gt_comp_labels: ndarray
                              ) -> Tuple[List[List[int]], List[List[int]]]:
        """Generate local graphs for GCN to predict which instance a text
        component belongs to.

        A pivot's local graph is its 1-hop kNN plus the 2-hop kNN of those
        neighbors, with the pivot placed first. A candidate graph is dropped
        as redundant when an already-kept graph satisfies all of:

        - the candidate pivot appears in the kept graph's kNN list;
        - both pivots share the same non-zero ground-truth label;
        - the IoU of the two neighbor sets exceeds ``local_graph_thr``.

        Args:
            sorted_dist_inds (ndarray): The complete graph node indices, which
                is sorted according to the Euclidean distance.
            gt_comp_labels(ndarray): The ground truth labels define the
                instance to which the text components (nodes in graphs) belong.

        Returns:
            Tuple(pivot_local_graphs, pivot_knns):

            - pivot_local_graphs (list[list[int]]): The list of local graph
              neighbor indices of pivots.
            - pivot_knns (list[list[int]]): The list of k-nearest neighbor
              indices of pivots.
        """
        assert sorted_dist_inds.ndim == 2
        assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
                gt_comp_labels.shape[0])

        # Column 0 of each row is skipped; the following k_at_hops[0]
        # columns are the 1-hop neighbors.
        knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
        pivot_local_graphs = []
        pivot_knns = []
        # Neighbor sets of the kept graphs, cached for the IoU test.
        kept_neighbor_sets = []

        for pivot_ind, knn in enumerate(knn_graph):
            local_graph_neighbors = set(knn)
            for neighbor_ind in knn:
                local_graph_neighbors.update(
                    set(sorted_dist_inds[neighbor_ind,
                                         1:self.k_at_hops[1] + 1]))
            local_graph_neighbors.discard(pivot_ind)
            pivot_local_graph = list(local_graph_neighbors)
            pivot_local_graph.insert(0, pivot_ind)
            pivot_knn = [pivot_ind] + list(knn)

            pivot_label = gt_comp_labels[pivot_ind]
            is_redundant = False
            # Unlabelled (0) pivots are never merged into an existing graph.
            if pivot_label != 0:
                for added_knn, added_neighbors in zip(pivot_knns,
                                                      kept_neighbor_sets):
                    # Cheap checks first; the IoU only matters if both pass.
                    if (pivot_ind not in added_knn
                            or gt_comp_labels[added_knn[0]] != pivot_label):
                        continue
                    union = len(local_graph_neighbors | added_neighbors)
                    intersect = len(local_graph_neighbors & added_neighbors)
                    if intersect / (union + 1e-8) > self.local_graph_thr:
                        is_redundant = True
                        break

            if not is_redundant:
                pivot_local_graphs.append(pivot_local_graph)
                pivot_knns.append(pivot_knn)
                kept_neighbor_sets.append(local_graph_neighbors)

        return pivot_local_graphs, pivot_knns

    def generate_gcn_input(
        self, node_feat_batch: List[Tensor], node_label_batch: List[ndarray],
        local_graph_batch: List[List[List[int]]],
        knn_batch: List[List[List[int]]], sorted_dist_ind_batch: List[ndarray]
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Generate graph convolution network input data.

        Every local graph in the batch is zero-padded to the largest local
        graph size so that all graphs can be stacked.

        Args:
            node_feat_batch (List[Tensor]): The batched graph node features.
            node_label_batch (List[ndarray]): The batched text component
                labels.
            local_graph_batch (List[List[List[int]]]): The local graph node
                indices of image batch.
            knn_batch (List[List[List[int]]]): The knn graph node indices of
                image batch.
            sorted_dist_ind_batch (List[ndarray]): The node indices sorted
                according to the Euclidean distance.

        Returns:
            Tuple(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
            gt_linkage):

            - local_graphs_node_feat (Tensor): The node features of graph,
              relative to each graph's pivot feature.
            - adjacent_matrices (Tensor): The normalized adjacent matrices of
              local graphs.
            - pivots_knn_inds (Tensor): The k-nearest neighbor indices in
              local graph.
            - gt_linkage (Tensor): The supervision signal of GCN for linkage
              prediction: 1 where a kNN shares the pivot's positive label.
        """
        assert isinstance(node_feat_batch, list)
        assert isinstance(node_label_batch, list)
        assert isinstance(local_graph_batch, list)
        assert isinstance(knn_batch, list)
        assert isinstance(sorted_dist_ind_batch, list)

        num_max_nodes = max(
            len(pivot_local_graph) for pivot_local_graphs in local_graph_batch
            for pivot_local_graph in pivot_local_graphs)

        local_graphs_node_feat = []
        adjacent_matrices = []
        pivots_knn_inds = []
        pivots_gt_linkage = []

        for batch_ind, sorted_dist_inds in enumerate(sorted_dist_ind_batch):
            node_feats = node_feat_batch[batch_ind]
            pivot_local_graphs = local_graph_batch[batch_ind]
            pivot_knns = knn_batch[batch_ind]
            node_labels = node_label_batch[batch_ind]
            device = node_feats.device

            for pivot_local_graph, pivot_knn in zip(pivot_local_graphs,
                                                    pivot_knns):
                num_nodes = len(pivot_local_graph)
                pivot_ind = pivot_local_graph[0]
                # Global node index -> position inside this local graph.
                node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}

                knn_ind_list = [node2ind_map[i] for i in pivot_knn[1:]]
                knn_inds = torch.tensor(knn_ind_list)
                pivot_feats = node_feats[pivot_ind]
                normalized_feats = node_feats[pivot_local_graph] - pivot_feats

                adjacent_matrix = np.zeros((num_nodes, num_nodes),
                                           dtype=np.float32)
                for node in pivot_local_graph:
                    neighbors = sorted_dist_inds[
                        node, 1:self.num_adjacent_linkages + 1]
                    for neighbor in neighbors:
                        # Dict lookup instead of an O(n) list scan.
                        if neighbor in node2ind_map:
                            src = node2ind_map[node]
                            dst = node2ind_map[neighbor]
                            adjacent_matrix[src, dst] = 1
                            adjacent_matrix[dst, src] = 1

                adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)
                pad_adjacent_matrix = torch.zeros(
                    (num_max_nodes, num_max_nodes),
                    dtype=torch.float,
                    device=device)
                pad_adjacent_matrix[:num_nodes, :num_nodes] = torch.from_numpy(
                    adjacent_matrix)

                # Padding rows inherit the features' dtype and device.
                pad_normalized_feats = torch.cat([
                    normalized_feats,
                    normalized_feats.new_zeros(
                        (num_max_nodes - num_nodes, normalized_feats.shape[1]))
                ],
                                                 dim=0)

                local_graph_labels = node_labels[pivot_local_graph]
                knn_labels = local_graph_labels[knn_ind_list]
                link_labels = ((node_labels[pivot_ind] == knn_labels) &
                               (node_labels[pivot_ind] > 0)).astype(np.int64)
                link_labels = torch.from_numpy(link_labels)

                local_graphs_node_feat.append(pad_normalized_feats)
                adjacent_matrices.append(pad_adjacent_matrix)
                pivots_knn_inds.append(knn_inds)
                pivots_gt_linkage.append(link_labels)

        local_graphs_node_feat = torch.stack(local_graphs_node_feat, 0)
        adjacent_matrices = torch.stack(adjacent_matrices, 0)
        pivots_knn_inds = torch.stack(pivots_knn_inds, 0)
        pivots_gt_linkage = torch.stack(pivots_gt_linkage, 0)

        return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
                pivots_gt_linkage)

    def __call__(self, feat_maps: Tensor, comp_attribs: ndarray
                 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Generate local graphs as GCN input.

        Args:
            feat_maps (Tensor): The feature maps to extract the content
                features of text components.
            comp_attribs (ndarray): The text component attributes, shape
                ``(B, max_comps, 8)``. Per image, ``[0, 0]`` holds the number
                of valid components, columns 1:7 the geometric attributes
                (centre x/y first, cos and sin last), and column 7 the label.
                This array is not modified.

        Returns:
            Tuple(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
            gt_linkage):

            - local_graphs_node_feat (Tensor): The node features of graph.
            - adjacent_matrices (Tensor): The adjacent matrices of local
              graphs.
            - pivots_knn_inds (Tensor): The k-nearest neighbor indices in local
              graph.
            - gt_linkage (Tensor): The supervision signal of GCN for linkage
              prediction.
        """
        assert isinstance(feat_maps, Tensor)
        assert comp_attribs.ndim == 3
        assert comp_attribs.shape[2] == 8

        sorted_dist_inds_batch = []
        local_graph_batch = []
        knn_batch = []
        node_feat_batch = []
        node_label_batch = []
        device = feat_maps.device

        for batch_ind in range(comp_attribs.shape[0]):
            num_comps = int(comp_attribs[batch_ind, 0, 0])
            # Copy: the cos column is clipped in place below, and a view
            # would write that change back into the caller's array.
            comp_geo_attribs = comp_attribs[batch_ind, :num_comps, 1:7].copy()
            node_labels = comp_attribs[batch_ind, :num_comps,
                                       7].astype(np.int32)

            comp_centers = comp_geo_attribs[:, 0:2]
            distance_matrix = euclidean_distance_matrix(
                comp_centers, comp_centers)

            # Each image is pooled on its own (unsqueezed) feature map, so
            # the RoI batch index is always 0.
            batch_id = np.zeros((comp_geo_attribs.shape[0], 1),
                                dtype=np.float32)
            comp_geo_attribs[:, -2] = np.clip(comp_geo_attribs[:, -2], -1, 1)
            # Signed angle recovered from (cos, sin).
            angle = np.arccos(comp_geo_attribs[:, -2]) * np.sign(
                comp_geo_attribs[:, -1])
            angle = angle.reshape((-1, 1))
            rotated_rois = np.hstack(
                [batch_id, comp_geo_attribs[:, :-2], angle])
            rois = torch.from_numpy(rotated_rois).to(device)
            content_feats = self.pooling(feat_maps[batch_ind].unsqueeze(0),
                                         rois)

            content_feats = content_feats.view(content_feats.shape[0],
                                               -1).to(device)
            geo_feats = feature_embedding(comp_geo_attribs,
                                          self.node_geo_feat_dim)
            geo_feats = torch.from_numpy(geo_feats).to(device)
            node_feats = torch.cat([content_feats, geo_feats], dim=-1)

            sorted_dist_inds = np.argsort(distance_matrix, axis=1)
            pivot_local_graphs, pivot_knns = self.generate_local_graphs(
                sorted_dist_inds, node_labels)

            node_feat_batch.append(node_feats)
            node_label_batch.append(node_labels)
            local_graph_batch.append(pivot_local_graphs)
            knn_batch.append(pivot_knns)
            sorted_dist_inds_batch.append(sorted_dist_inds)

        (node_feats, adjacent_matrices, knn_inds, gt_linkage) = \
            self.generate_gcn_input(node_feat_batch,
                                    node_label_batch,
                                    local_graph_batch,
                                    knn_batch,
                                    sorted_dist_inds_batch)

        return node_feats, adjacent_matrices, knn_inds, gt_linkage
class ProposalLocalGraphs: | |
"""Propose text components and generate local graphs for GCN to classify | |
the k-nearest neighbors of a pivot in `DRRG: Deep Relational Reasoning | |
Graph Network for Arbitrary Shape Text Detection. | |
<https://arxiv.org/abs/2003.07493>`_. | |
This code was partially adapted from https://github.com/GXYM/DRRG licensed | |
under the MIT license. | |
Args: | |
k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2. | |
num_adjacent_linkages (int): The number of linkages when constructing | |
adjacent matrix. | |
node_geo_feat_len (int): The length of embedded geometric feature | |
vector of a text component. | |
pooling_scale (float): The spatial scale of rotated RoI-Align. | |
pooling_output_size (tuple(int)): The output size of rotated RoI-Align. | |
nms_thr (float): The locality-aware NMS threshold for text components. | |
min_width (float): The minimum width of text components. | |
max_width (float): The maximum width of text components. | |
comp_shrink_ratio (float): The shrink ratio of text components. | |
comp_w_h_ratio (float): The width to height ratio of text components. | |
comp_score_thr (float): The score threshold of text component. | |
text_region_thr (float): The threshold for text region probability map. | |
center_region_thr (float): The threshold for text center region | |
probability map. | |
center_region_area_thr (int): The threshold for filtering small-sized | |
text center region. | |
""" | |
def __init__(self, k_at_hops: Tuple[int, int], num_adjacent_linkages: int,
             node_geo_feat_len: int, pooling_scale: float,
             pooling_output_size: Sequence[int], nms_thr: float,
             min_width: float, max_width: float, comp_shrink_ratio: float,
             comp_w_h_ratio: float, comp_score_thr: float,
             text_region_thr: float, center_region_thr: float,
             center_region_area_thr: int) -> None:
    """Validate the component-proposal hyper-parameters and store them.

    The meaning of each argument is listed in the class docstring.
    ``num_adjacent_linkages`` is stored as ``active_connection`` and
    ``node_geo_feat_len`` as ``node_geo_feat_dim``.
    """
    # The length check runs first, matching the original ordering.
    assert len(k_at_hops) == 2
    type_checks = (
        (k_at_hops, tuple),
        (num_adjacent_linkages, int),
        (node_geo_feat_len, int),
        (pooling_scale, float),
        (pooling_output_size, tuple),
        (nms_thr, float),
        (min_width, float),
        (max_width, float),
        (comp_shrink_ratio, float),
        (comp_w_h_ratio, float),
        (comp_score_thr, float),
        (text_region_thr, float),
        (center_region_thr, float),
        (center_region_area_thr, int),
    )
    for value, expected_type in type_checks:
        assert isinstance(value, expected_type)

    # Graph-construction settings.
    self.k_at_hops = k_at_hops
    self.active_connection = num_adjacent_linkages
    self.local_graph_depth = len(self.k_at_hops)
    self.node_geo_feat_dim = node_geo_feat_len
    self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)

    # Component-proposal settings.
    self.nms_thr = nms_thr
    self.min_width = min_width
    self.max_width = max_width
    self.comp_shrink_ratio = comp_shrink_ratio
    self.comp_w_h_ratio = comp_w_h_ratio
    self.comp_score_thr = comp_score_thr
    self.text_region_thr = text_region_thr
    self.center_region_thr = center_region_thr
    self.center_region_area_thr = center_region_area_thr
def propose_comps(self, score_map: ndarray, top_height_map: ndarray,
                  bot_height_map: ndarray, sin_map: ndarray,
                  cos_map: ndarray, comp_score_thr: float,
                  min_width: float, max_width: float,
                  comp_shrink_ratio: float,
                  comp_w_h_ratio: float) -> ndarray:
    """Build one rotated box per pixel whose score exceeds a threshold.

    Each selected pixel ``(y, x)`` is moved along ``(sin, cos)`` by its
    shrunk top/bottom heights to get the top and bottom mid-points. Each
    mid-point is then widened by half the clipped component width.

    Args:
        score_map (ndarray): The score map for NMS.
        top_height_map (ndarray): The predicted text height map from each
            pixel in text center region to top sideline.
        bot_height_map (ndarray): The predicted text height map from each
            pixel in text center region to bottom sideline.
        sin_map (ndarray): The predicted sin(theta) map.
        cos_map (ndarray): The predicted cos(theta) map.
        comp_score_thr (float): The score threshold of text component.
        min_width (float): The minimum width of text components.
        max_width (float): The maximum width of text components.
        comp_shrink_ratio (float): The shrink ratio of text components.
        comp_w_h_ratio (float): The width to height ratio of text
            components.

    Returns:
        ndarray: Array of shape ``(K, 9)``. Each row holds the corners in
        ``tl, tr, br, bl`` order as ``(x, y)`` pairs, followed by the
        pixel's score.
    """
    centers_yx = np.argwhere(score_map > comp_score_thr)
    row_order = np.argsort(centers_yx[:, 0])
    centers_yx = centers_yx[row_order]
    ys, xs = centers_yx[:, 0], centers_yx[:, 1]

    top_h = top_height_map[ys, xs].reshape((-1, 1)) * comp_shrink_ratio
    bot_h = bot_height_map[ys, xs].reshape((-1, 1)) * comp_shrink_ratio
    sin = sin_map[ys, xs].reshape((-1, 1))
    cos = cos_map[ys, xs].reshape((-1, 1))

    # Mid-points are computed in (y, x) order, then flipped to (x, y).
    top_mid_xy = (centers_yx + np.hstack([top_h * sin, top_h * cos]))[:, ::-1]
    bot_mid_xy = (centers_yx - np.hstack([bot_h * sin, bot_h * cos]))[:, ::-1]

    half_width = np.clip((top_h + bot_h) * comp_w_h_ratio, min_width,
                         max_width) / 2
    half_span = np.hstack([-half_width * sin, half_width * cos])

    corners = np.hstack([
        top_mid_xy - half_span,  # tl
        top_mid_xy + half_span,  # tr
        bot_mid_xy + half_span,  # br
        bot_mid_xy - half_span,  # bl
    ]).astype(np.float32)
    pixel_scores = score_map[ys, xs].reshape((-1, 1))
    return np.hstack([corners, pixel_scores])
def propose_comps_and_attribs(self, text_region_map: ndarray,
                              center_region_map: ndarray,
                              top_height_map: ndarray,
                              bot_height_map: ndarray, sin_map: ndarray,
                              cos_map: ndarray) -> Tuple[ndarray, ndarray]:
    """Generate text components and attributes.

    Processing steps:

    1. Each sufficiently large text-center contour yields candidate
       components via ``propose_comps``.
    2. Candidates are merged with locality-aware NMS.
    3. A contour's components are kept only if at least half of their drawn
       area lies inside the text-region mask.
    4. Surviving boxes are clipped to the map and re-scored by the mean
       text-region probability inside each box.

    Args:
        text_region_map (ndarray): The predicted text region probability
            map.
        center_region_map (ndarray): The predicted text center region
            probability map.
        top_height_map (ndarray): The predicted text height map from each
            pixel in text center region to top sideline.
        bot_height_map (ndarray): The predicted text height map from each
            pixel in text center region to bottom sideline.
        sin_map (ndarray): The predicted sin(theta) map.
        cos_map (ndarray): The predicted cos(theta) map.

    Returns:
        tuple(ndarray, ndarray): ``(comp_attribs, text_comps)``, or
        ``(None, None)`` when no contour yields any component.

        - comp_attribs (ndarray): One row ``(x, y, h, w, cos, sin)`` per
          component. ``(x, y)`` is the box center clamped to the map.
        - text_comps (ndarray): The text components. The first eight columns
          are the four clipped corner points; the last column is the mean
          text-region score inside the box.

    Raises:
        ImportError: If a component must be NMS-merged but ``lanms-neo`` is
            not installed.
    """
    assert (text_region_map.shape == center_region_map.shape ==
            top_height_map.shape == bot_height_map.shape == sin_map.shape
            == cos_map.shape)
    text_mask = text_region_map > self.text_region_thr
    center_region_mask = (center_region_map >
                          self.center_region_thr) * text_mask

    # Re-normalise so that sin^2 + cos^2 == 1 at every pixel.
    scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2 + 1e-8))
    sin_map, cos_map = sin_map * scale, cos_map * scale

    center_region_mask = fill_hole(center_region_mask)
    center_region_contours, _ = cv2.findContours(
        center_region_mask.astype(np.uint8), cv2.RETR_TREE,
        cv2.CHAIN_APPROX_SIMPLE)

    mask_sz = center_region_map.shape
    comp_list = []
    for contour in center_region_contours:
        current_center_mask = np.zeros(mask_sz)
        cv2.drawContours(current_center_mask, [contour], -1, 1, -1)
        if current_center_mask.sum() <= self.center_region_area_thr:
            continue
        score_map = text_region_map * current_center_mask

        text_comps = self.propose_comps(score_map, top_height_map,
                                        bot_height_map, sin_map, cos_map,
                                        self.comp_score_thr,
                                        self.min_width, self.max_width,
                                        self.comp_shrink_ratio,
                                        self.comp_w_h_ratio)

        if la_nms is None:
            raise ImportError('lanms-neo is not installed, '
                              'please run "pip install lanms-neo==1.0.2".')
        text_comps = la_nms(text_comps, self.nms_thr)
        text_comp_mask = np.zeros(mask_sz)
        text_comp_boxes = text_comps[:, :8].reshape(
            (-1, 4, 2)).astype(np.int32)

        cv2.drawContours(text_comp_mask, text_comp_boxes, -1, 1, -1)
        # Drop the contour if most of its component area is outside text.
        if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5:
            continue
        # Check the row count: the column count is always non-zero, so
        # testing `shape[-1]` let empty results through.
        if text_comps.shape[0] > 0:
            comp_list.append(text_comps)

    if len(comp_list) <= 0:
        return None, None

    text_comps = np.vstack(comp_list)
    text_comp_boxes = text_comps[:, :8].reshape((-1, 4, 2))
    centers = np.mean(text_comp_boxes, axis=1).astype(np.int32)
    # NMS-merged boxes can have centers outside the map. Clamp them so the
    # per-pixel lookups below neither raise IndexError nor silently wrap
    # around through negative indices.
    x = np.clip(centers[:, 0], 0, mask_sz[1] - 1)
    y = np.clip(centers[:, 1], 0, mask_sz[0] - 1)

    scores = []
    for text_comp_box in text_comp_boxes:
        # `text_comp_boxes` is a view into `text_comps`, so this in-place
        # clipping also clips the corners returned below.
        text_comp_box[:, 0] = np.clip(text_comp_box[:, 0], 0,
                                      mask_sz[1] - 1)
        text_comp_box[:, 1] = np.clip(text_comp_box[:, 1], 0,
                                      mask_sz[0] - 1)
        min_coord = np.min(text_comp_box, axis=0).astype(np.int32)
        max_coord = np.max(text_comp_box, axis=0).astype(np.int32)
        text_comp_box = text_comp_box - min_coord
        box_sz = (max_coord - min_coord + 1)
        temp_comp_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
        cv2.fillPoly(temp_comp_mask, [text_comp_box.astype(np.int32)], 1)
        temp_region_patch = text_region_map[min_coord[1]:(max_coord[1] + 1),
                                            min_coord[0]:(max_coord[0] + 1)]
        # Mean text-region probability inside the box polygon.
        score = cv2.mean(temp_region_patch, temp_comp_mask)[0]
        scores.append(score)
    scores = np.array(scores).reshape((-1, 1))
    text_comps = np.hstack([text_comps[:, :-1], scores])

    h = top_height_map[y, x].reshape(
        (-1, 1)) + bot_height_map[y, x].reshape((-1, 1))
    w = np.clip(h * self.comp_w_h_ratio, self.min_width, self.max_width)
    sin = sin_map[y, x].reshape((-1, 1))
    cos = cos_map[y, x].reshape((-1, 1))

    x = x.reshape((-1, 1))
    y = y.reshape((-1, 1))
    comp_attribs = np.hstack([x, y, h, w, cos, sin])

    return comp_attribs, text_comps
def generate_local_graphs(self, sorted_dist_inds: ndarray,
                          node_feats: Tensor
                          ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Generate local graphs and graph convolution network input data.

    For every node (pivot) a local graph is built from its
    ``self.k_at_hops[0]`` nearest neighbours plus, for each of those, their
    own ``self.k_at_hops[1]`` nearest neighbours. Inside a local graph each
    node is linked (symmetrically) to those of its
    ``self.active_connection`` nearest neighbours that also belong to the
    local graph. All local graphs are zero-padded to the size of the
    largest one so they can be stacked.

    Args:
        sorted_dist_inds (ndarray): The node indices sorted according to
            the Euclidean distance, shape ``(N, N)``. Column 0 is skipped,
            presumably because it holds the node itself (TODO confirm for
            duplicate centers, where argsort ties may reorder).
        node_feats (Tensor): The features of nodes in graph, shape
            ``(N, feat_dim)``.

    Returns:
        Tuple(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
        pivots_local_graphs):

        - local_graphs_node_feat (Tensor): Node features of each local
          graph minus its pivot's feature, zero-padded; shape
          ``(N, num_max_nodes, feat_dim)``.
        - adjacent_matrices (Tensor): Normalized adjacency matrices,
          zero-padded; shape ``(N, num_max_nodes, num_max_nodes)``.
        - pivots_knn_inds (Tensor): Positions, inside each local graph, of
          the pivot's nearest neighbours; shape
          ``(N, min(k_at_hops[0], N - 1))``.
        - pivots_local_graphs (Tensor): Global node indices making up each
          local graph (pivot first), zero-padded; shape
          ``(N, num_max_nodes)``. Note: this tensor is created on CPU,
          unlike the other three which live on ``node_feats.device``.
    """
    assert sorted_dist_inds.ndim == 2
    assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
            node_feats.shape[0])

    knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
    device = node_feats.device

    # Pass 1: collect the node set of every pivot's local graph.
    pivot_local_graphs = []
    pivot_knns = []
    for pivot_ind, knn in enumerate(knn_graph):
        local_graph_neighbors = set(knn)
        for neighbor_ind in knn:
            local_graph_neighbors.update(
                sorted_dist_inds[neighbor_ind, 1:self.k_at_hops[1] + 1])
        local_graph_neighbors.discard(pivot_ind)
        # The pivot is always stored first so that position 0 is the pivot.
        pivot_local_graph = list(local_graph_neighbors)
        pivot_local_graph.insert(0, pivot_ind)
        pivot_local_graphs.append(pivot_local_graph)
        pivot_knns.append([pivot_ind] + list(knn))

    num_max_nodes = max(len(graph) for graph in pivot_local_graphs)

    # Pass 2: build padded features / adjacency / index tensors per graph.
    local_graphs_node_feat = []
    adjacent_matrices = []
    pivots_knn_inds = []
    pivots_local_graphs = []
    for pivot_local_graph, pivot_knn in zip(pivot_local_graphs, pivot_knns):
        num_nodes = len(pivot_local_graph)
        pivot_ind = pivot_local_graph[0]
        # Global node index -> position inside this local graph.
        node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}
        knn_inds = torch.tensor([node2ind_map[i] for i in pivot_knn[1:]
                                 ]).long().to(device)

        # Features are expressed relative to the pivot's feature.
        pivot_feats = node_feats[pivot_ind]
        normalized_feats = node_feats[pivot_local_graph] - pivot_feats

        adjacent_matrix = np.zeros((num_nodes, num_nodes))
        for node in pivot_local_graph:
            neighbors = sorted_dist_inds[node, 1:self.active_connection + 1]
            node_pos = node2ind_map[node]
            for neighbor in neighbors:
                # O(1) dict membership instead of scanning the node list.
                if neighbor in node2ind_map:
                    neighbor_pos = node2ind_map[neighbor]
                    adjacent_matrix[node_pos, neighbor_pos] = 1
                    adjacent_matrix[neighbor_pos, node_pos] = 1
        adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)

        pad_adjacent_matrix = torch.zeros((num_max_nodes, num_max_nodes),
                                          dtype=torch.float,
                                          device=device)
        pad_adjacent_matrix[:num_nodes, :num_nodes] = torch.from_numpy(
            adjacent_matrix)

        pad_normalized_feats = torch.cat([
            normalized_feats,
            torch.zeros(
                (num_max_nodes - num_nodes, normalized_feats.shape[1]),
                dtype=torch.float,
                device=device)
        ],
                                         dim=0)

        local_graph_nodes = torch.tensor(pivot_local_graph)
        local_graph_nodes = torch.cat([
            local_graph_nodes,
            torch.zeros(num_max_nodes - num_nodes, dtype=torch.long)
        ],
                                      dim=-1)

        local_graphs_node_feat.append(pad_normalized_feats)
        adjacent_matrices.append(pad_adjacent_matrix)
        pivots_knn_inds.append(knn_inds)
        pivots_local_graphs.append(local_graph_nodes)

    local_graphs_node_feat = torch.stack(local_graphs_node_feat, 0)
    adjacent_matrices = torch.stack(adjacent_matrices, 0)
    pivots_knn_inds = torch.stack(pivots_knn_inds, 0)
    pivots_local_graphs = torch.stack(pivots_local_graphs, 0)

    return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
            pivots_local_graphs)
def __call__(self, preds: Tensor, feat_maps: Tensor
             ) -> Tuple[bool, Tuple]:
    """Generate local graphs and graph convolutional network input data.

    Args:
        preds (Tensor): The predicted maps, channel axis first, optionally
            with a leading batch axis of size 1. Channels 0-5 are read as
            text region, center region, sin, cos, top height and bottom
            height maps (the first two through a sigmoid).
        feat_maps (Tensor): The feature maps to extract content feature of
            text components.

    Returns:
        Tuple(none_flag, graph_data):

        - none_flag (bool): True when ``propose_comps_and_attribs`` yields
          no components or fewer than two; ``graph_data`` is then the
          placeholder ``(0, 0, 0, 0, 0)``.
        - graph_data (tuple): ``(local_graphs_node_feat, adjacent_matrices,
          pivots_knn_inds, pivots_local_graphs, text_comps)`` where the
          first four are the tensors returned by ``generate_local_graphs``
          and ``text_comps`` (ndarray) are the predicted text components.
    """
    if preds.ndim == 4:
        assert preds.shape[0] == 1
        # Drop only the batch axis: a bare squeeze() would also remove a
        # spatial axis of size 1 and corrupt the per-channel maps below.
        preds = preds.squeeze(0)
    pred_text_region = torch.sigmoid(preds[0]).data.cpu().numpy()
    pred_center_region = torch.sigmoid(preds[1]).data.cpu().numpy()
    pred_sin_map = preds[2].data.cpu().numpy()
    pred_cos_map = preds[3].data.cpu().numpy()
    pred_top_height_map = preds[4].data.cpu().numpy()
    pred_bot_height_map = preds[5].data.cpu().numpy()
    device = preds.device

    comp_attribs, text_comps = self.propose_comps_and_attribs(
        pred_text_region, pred_center_region, pred_top_height_map,
        pred_bot_height_map, pred_sin_map, pred_cos_map)
    # With a single component the neighbour slice [1:] would be empty, so
    # at least two components are required to form any local graph.
    if comp_attribs is None or len(comp_attribs) < 2:
        none_flag = True
        return none_flag, (0, 0, 0, 0, 0)

    comp_centers = comp_attribs[:, 0:2]
    distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers)

    geo_feats = feature_embedding(comp_attribs, self.node_geo_feat_dim)
    geo_feats = torch.from_numpy(geo_feats).to(device)

    # comp_attribs columns are (x, y, h, w, cos, sin); the signed rotation
    # angle is recovered from cos and the sign of sin.
    batch_id = np.zeros((comp_attribs.shape[0], 1), dtype=np.float32)
    comp_attribs = comp_attribs.astype(np.float32)
    angle = np.arccos(comp_attribs[:, -2]) * np.sign(comp_attribs[:, -1])
    angle = angle.reshape((-1, 1))
    # NOTE(review): comp_attribs[:, :-2] is (x, y, h, w), so h occupies the
    # slot that mmcv's RoIAlignRotated documents as width — kept unchanged
    # for consistency with trained weights; confirm this is intended.
    rotated_rois = np.hstack([batch_id, comp_attribs[:, :-2], angle])
    rois = torch.from_numpy(rotated_rois).to(device)
    content_feats = self.pooling(feat_maps, rois)
    content_feats = content_feats.view(content_feats.shape[0],
                                       -1).to(device)
    node_feats = torch.cat([content_feats, geo_feats], dim=-1)

    sorted_dist_inds = np.argsort(distance_matrix, axis=1)
    (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
     pivots_local_graphs) = self.generate_local_graphs(
         sorted_dist_inds, node_feats)

    none_flag = False
    return none_flag, (local_graphs_node_feat, adjacent_matrices,
                       pivots_knn_inds, pivots_local_graphs, text_comps)
class GraphConv(BaseModule):
    """Graph convolution layer.

    Every node feature is concatenated with its aggregated neighbourhood
    feature (``A @ features``), projected with a learnable
    ``(2 * in_dim, out_dim)`` weight, shifted by a bias and passed through
    ReLU.

    Args:
        in_dim (int): The number of input channels.
        out_dim (int): The number of output channels.
    """

    class MeanAggregator(BaseModule):
        """Neighbour aggregator used by :class:`GraphConv`.

        It only computes the batched product ``A @ features``; any
        averaging must already be encoded in ``A`` (presumably a
        normalized adjacency matrix — confirm with callers).
        """

        def forward(self, features: Tensor, A: Tensor) -> Tensor:
            """Return ``torch.bmm(A, features)``."""
            return torch.bmm(A, features)

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Registered in this order (weight, then bias) to keep the
        # state_dict layout and the RNG draw sequence stable.
        self.weight = nn.Parameter(torch.FloatTensor(2 * in_dim, out_dim))
        self.bias = nn.Parameter(torch.FloatTensor(out_dim))
        init.xavier_uniform_(self.weight)
        init.constant_(self.bias, 0)

        self.aggregator = self.MeanAggregator()

    def forward(self, features: Tensor, A: Tensor) -> Tensor:
        """Apply one graph-convolution step.

        Args:
            features (Tensor): Node features of shape
                ``(batch, num_nodes, in_dim)``.
            A (Tensor): Batched adjacency matrices used by the aggregator.

        Returns:
            Tensor: ReLU-activated features of shape
            ``(batch, num_nodes, out_dim)``.
        """
        _, _, feat_dim = features.shape
        assert feat_dim == self.in_dim

        neighbour_feats = self.aggregator(features, A)
        stacked = torch.cat((features, neighbour_feats), dim=2)
        projected = torch.einsum('bnd,df->bnf', stacked, self.weight)
        return F.relu(projected + self.bias)
class GCN(BaseModule):
    """Graph convolutional network for clustering. This was from repo
    https://github.com/Zhongdao/gcn_clustering licensed under the MIT license.

    Four stacked :class:`GraphConv` layers (feat_len -> 512 -> 256 -> 128 ->
    64) are followed by a small MLP that outputs two logits per
    (pivot, neighbour) pair.

    Args:
        feat_len (int): The input node feature length.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 feat_len: int,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.bn0 = nn.BatchNorm1d(feat_len, affine=False).float()
        self.conv1 = GraphConv(feat_len, 512)
        self.conv2 = GraphConv(512, 256)
        self.conv3 = GraphConv(256, 128)
        self.conv4 = GraphConv(128, 64)
        self.classifier = nn.Sequential(
            nn.Linear(64, 32), nn.PReLU(32), nn.Linear(32, 2))

    def forward(self, node_feats: Tensor, adj_mats: Tensor,
                knn_inds: Tensor) -> Tensor:
        """Forward function.

        Args:
            node_feats (Tensor): The node features of the local graphs,
                shape ``(num_local_graphs, num_max_nodes, feat_len)``.
            adj_mats (Tensor): The adjacent matrices of the local graphs,
                batched over ``num_local_graphs``.
            knn_inds (Tensor): Long tensor of shape
                ``(num_local_graphs, k)`` holding, for each local graph,
                the node positions of the pivot's k nearest neighbours.

        Returns:
            Tensor: Logits of shape ``(num_local_graphs * k, 2)``, one row
            per (local graph, neighbour) pair in row-major order.
        """
        num_local_graphs, num_max_nodes, feat_len = node_feats.shape

        # BatchNorm1d expects (N, C): normalize all nodes jointly.
        node_feats = node_feats.reshape(-1, feat_len)
        node_feats = self.bn0(node_feats)
        node_feats = node_feats.reshape(num_local_graphs, num_max_nodes,
                                        feat_len)

        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            node_feats = conv(node_feats, adj_mats)

        # Gather the features of each pivot's k nearest neighbours for all
        # local graphs at once: row g of graph_inds pairs with row g of
        # knn_inds, giving a (num_local_graphs, k, mid_feat_len) tensor.
        mid_feat_len = node_feats.size(-1)
        graph_inds = torch.arange(
            num_local_graphs, device=knn_inds.device).unsqueeze(-1)
        edge_feat = node_feats[graph_inds, knn_inds]
        edge_feat = edge_feat.reshape(-1, mid_feat_len)

        pred = self.classifier(edge_feat)
        return pred