# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from mmdet.models.utils import multi_apply
from torch import nn

from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from .seg_based_module_loss import SegBasedModuleLoss


@MODELS.register_module()
class PANModuleLoss(SegBasedModuleLoss):
    """The class for implementing PANet loss. This was partially adapted
    from https://github.com/whai362/pan_pp.pytorch and
    https://github.com/WenmuZhou/PAN.pytorch.

    PANet: `Efficient and Accurate Arbitrary-Shaped Text Detection with
    Pixel Aggregation Network <https://arxiv.org/abs/1908.05900>`_.

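    Examples:
        A minimal construction sketch with the default sub-losses. It
        assumes the referenced loss types are registered in ``MODELS``
        (they are built from the registry in ``__init__``), and the import
        path below follows this file's package layout:

        >>> from mmocr.models.textdet.module_losses import PANModuleLoss
        >>> pan_loss = PANModuleLoss(weight_text=1.0, weight_kernel=0.5)
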
    Args:
        loss_text (dict): The loss config for text map. Defaults to
            dict(type='MaskedSquareDiceLoss').
        loss_kernel (dict): The loss config for kernel map. Defaults to
            dict(type='MaskedSquareDiceLoss').
        loss_embedding (dict): The loss config for embedding map. Defaults
            to dict(type='PANEmbLossV1').
        weight_text (float): The weight of text loss. Defaults to 1.
        weight_kernel (float): The weight of kernel loss. Defaults to 0.5.
        weight_embedding (float): The weight of embedding loss.
            Defaults to 0.25.
        ohem_ratio (float): The negative/positive ratio in ohem. Defaults
            to 3.
        shrink_ratio (tuple[float]): The ratio of shrinking kernel. Defaults
            to (1.0, 0.5).
        max_shrink_dist (int or float): The maximum shrinking distance.
            Defaults to 20.
        reduction (str): The way to reduce the loss. Available options are
            "mean" and "sum". Defaults to 'mean'.
    """

    def __init__(
            self,
            loss_text: Dict = dict(type='MaskedSquareDiceLoss'),
            loss_kernel: Dict = dict(type='MaskedSquareDiceLoss'),
            loss_embedding: Dict = dict(type='PANEmbLossV1'),
            weight_text: float = 1.0,
            weight_kernel: float = 0.5,
            weight_embedding: float = 0.25,
            ohem_ratio: Union[int, float] = 3,  # TODO Find a better name
            shrink_ratio: Sequence[Union[int, float]] = (1.0, 0.5),
            max_shrink_dist: Union[int, float] = 20,
            reduction: str = 'mean') -> None:
        super().__init__()
        assert reduction in ['mean', 'sum'], \
            "reduction must be in ['mean', 'sum']"
        self.weight_text = weight_text
        self.weight_kernel = weight_kernel
        self.weight_embedding = weight_embedding
        self.shrink_ratio = shrink_ratio
        self.ohem_ratio = ohem_ratio
        self.reduction = reduction
        self.max_shrink_dist = max_shrink_dist
        self.loss_text = MODELS.build(loss_text)
        self.loss_kernel = MODELS.build(loss_kernel)
        self.loss_embedding = MODELS.build(loss_embedding)

    def forward(self, preds: torch.Tensor,
                data_samples: Sequence[TextDetDataSample]) -> Dict:
        """Compute PAN loss.

        Args:
            preds (Tensor): Raw predictions from the model, of shape
                :math:`(N, C, H, W)`.
            data_samples (list[TextDetDataSample]): The data samples.

        Returns:
            dict: The dict for PAN losses with ``loss_text``,
            ``loss_kernel`` and ``loss_embedding``.
        """
        gt_kernels, gt_masks = self.get_targets(data_samples)
        target_size = gt_kernels.size()[2:]
        preds = F.interpolate(preds, size=target_size, mode='bilinear')
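        # Split the prediction channels: channel 0 is the text region map,
        # channel 1 is the kernel map, and the remaining channels form the
        # per-pixel instance embedding vectors.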
        pred_texts = preds[:, 0, :, :]
        pred_kernels = preds[:, 1, :, :]
        inst_embed = preds[:, 2:, :, :]
        gt_kernels = gt_kernels.to(preds.device)
        gt_masks = gt_masks.to(preds.device)

        # compute embedding loss
        loss_emb = self.loss_embedding(inst_embed, gt_kernels[0],
                                       gt_kernels[1], gt_masks)
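        # Up to this point `gt_kernels` holds per-pixel instance indices,
        # which the embedding loss needs; binarize them for the dice losses.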
        gt_kernels[gt_kernels <= 0.5] = 0
        gt_kernels[gt_kernels > 0.5] = 1

        # compute text loss
        sampled_mask = self._ohem_batch(pred_texts.detach(), gt_kernels[0],
                                        gt_masks)
        pred_texts = torch.sigmoid(pred_texts)
        loss_texts = self.loss_text(pred_texts, gt_kernels[0], sampled_mask)

        # compute kernel loss
        pred_kernels = torch.sigmoid(pred_kernels)
        sampled_masks_kernel = (gt_kernels[0] > 0.5).float() * gt_masks
        loss_kernels = self.loss_kernel(pred_kernels, gt_kernels[1],
                                        sampled_masks_kernel)

        losses = [loss_texts, loss_kernels, loss_emb]
        if self.reduction == 'mean':
            losses = [item.mean() for item in losses]
        else:
            losses = [item.sum() for item in losses]

        results = dict()
        results.update(
            loss_text=self.weight_text * losses[0],
            loss_kernel=self.weight_kernel * losses[1],
            loss_embedding=self.weight_embedding * losses[2])
        return results

    def get_targets(
        self,
        data_samples: Sequence[TextDetDataSample],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate the gt targets for PANet.

        Args:
            data_samples (list[TextDetDataSample]): The data samples.

        Returns:
            tuple(gt_kernels, gt_masks):

            - gt_kernels (Tensor): The kernel targets of shape
              :math:`(kernel_number, N, H, W)`.
            - gt_masks (Tensor): The effective masks of shape
              :math:`(N, H, W)`.
        """
        gt_kernels, gt_masks = multi_apply(self._get_target_single,
                                           data_samples)
        # gt_kernels: (N, kernel_number, H, W)->(kernel_number, N, H, W)
        gt_kernels = torch.stack(gt_kernels, dim=0).permute(1, 0, 2, 3)
        gt_masks = torch.stack(gt_masks, dim=0)
        return gt_kernels, gt_masks

    def _get_target_single(self, data_sample: TextDetDataSample
                           ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate loss target from a data sample.

        Args:
            data_sample (TextDetDataSample): The data sample.

        Returns:
            tuple: A tuple of two tensors as the targets of one sample.
        """
        gt_polygons = data_sample.gt_instances.polygons
        gt_ignored = data_sample.gt_instances.ignored

        gt_kernels = []
        for ratio in self.shrink_ratio:
            # TODO pass `gt_ignored` to `_generate_kernels`
            gt_kernel, _ = self._generate_kernels(
                data_sample.img_shape,
                gt_polygons,
                ratio,
                ignore_flags=None,
                max_shrink_dist=self.max_shrink_dist)
            gt_kernels.append(gt_kernel)

        gt_polygons_ignored = data_sample.gt_instances[gt_ignored].polygons
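        # The effective mask zeroes out ignored polygons so that they
        # contribute neither positives nor negatives to the losses.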
        gt_mask = self._generate_effective_mask(data_sample.img_shape,
                                                gt_polygons_ignored)

        gt_kernels = np.stack(gt_kernels, axis=0)
        gt_kernels = torch.from_numpy(gt_kernels).float()
        gt_mask = torch.from_numpy(gt_mask).float()
        return gt_kernels, gt_mask

    def _ohem_batch(self, text_scores: torch.Tensor, gt_texts: torch.Tensor,
                    gt_mask: torch.Tensor) -> torch.Tensor:
        """OHEM sampling for a batch of imgs.

        Args:
            text_scores (Tensor): The text scores of size :math:`(N, H, W)`.
            gt_texts (Tensor): The gt text masks of size :math:`(N, H, W)`.
            gt_mask (Tensor): The gt effective mask of size
                :math:`(N, H, W)`.

        Returns:
            Tensor: The sampled mask of size :math:`(N, H, W)`.
        """
        assert isinstance(text_scores, torch.Tensor)
        assert isinstance(gt_texts, torch.Tensor)
        assert isinstance(gt_mask, torch.Tensor)
        assert len(text_scores.shape) == 3
        assert text_scores.shape == gt_texts.shape
        assert gt_texts.shape == gt_mask.shape

        sampled_masks = []
        for i in range(text_scores.shape[0]):
            sampled_masks.append(
                self._ohem_single(text_scores[i], gt_texts[i], gt_mask[i]))
        sampled_masks = torch.stack(sampled_masks)
        return sampled_masks

    def _ohem_single(self, text_score: torch.Tensor, gt_text: torch.Tensor,
                     gt_mask: torch.Tensor) -> torch.Tensor:
        """Sample the top-k maximal negative samples and all positive samples.

        Args:
            text_score (Tensor): The text score of size :math:`(H, W)`.
            gt_text (Tensor): The ground truth text mask of size
                :math:`(H, W)`.
            gt_mask (Tensor): The effective region mask of size
                :math:`(H, W)`.

        Returns:
            Tensor: The sampled pixel mask of size :math:`(H, W)`.
        """
        assert isinstance(text_score, torch.Tensor)
        assert isinstance(gt_text, torch.Tensor)
        assert isinstance(gt_mask, torch.Tensor)
        assert len(text_score.shape) == 2
        assert text_score.shape == gt_text.shape
        assert gt_text.shape == gt_mask.shape
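        # Positives are GT text pixels inside the effective region. The
        # number of negatives is capped at `ohem_ratio` times the positive
        # count, and only the highest-scoring (hardest) negatives are kept.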
        pos_num = int(torch.sum(gt_text > 0.5).item()) - int(
            torch.sum((gt_text > 0.5) * (gt_mask <= 0.5)).item())
        neg_num = int(torch.sum(gt_text <= 0.5).item())
        neg_num = int(min(pos_num * self.ohem_ratio, neg_num))

        if pos_num == 0 or neg_num == 0:
            warnings.warn('pos_num = 0 or neg_num = 0')
            return gt_mask.bool()

        neg_score = text_score[gt_text <= 0.5]
        neg_score_sorted, _ = torch.sort(neg_score, descending=True)
        threshold = neg_score_sorted[neg_num - 1]
        sampled_mask = (((text_score >= threshold) + (gt_text > 0.5)) > 0) * (
            gt_mask > 0.5)
        return sampled_mask


@MODELS.register_module()
class PANEmbLossV1(nn.Module):
    """The class for implementing EmbLossV1. This was partially adapted from
    https://github.com/whai362/pan_pp.pytorch.

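    Examples:
        A shape sketch with illustrative inputs; the random tensors below
        are placeholders rather than real training targets:

        >>> import torch
        >>> emb_loss = PANEmbLossV1(feature_dim=4)
        >>> emb = torch.rand(2, 4, 160, 160)  # (N, feature_dim, H, W)
        >>> instance = torch.randint(0, 3, (2, 160, 160)).float()
        >>> kernel = (instance > 0).float()
        >>> mask = torch.ones(2, 160, 160)
        >>> out = emb_loss(emb, instance, kernel, mask)  # shape (N,)
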
    Args:
        feature_dim (int): The dimension of the feature. Defaults to 4.
        delta_aggregation (float): The delta for aggregation. Defaults
            to 0.5.
        delta_discrimination (float): The delta for discrimination.
            Defaults to 1.5.
    """

    def __init__(self,
                 feature_dim: int = 4,
                 delta_aggregation: float = 0.5,
                 delta_discrimination: float = 1.5) -> None:
        super().__init__()
        self.feature_dim = feature_dim
        self.delta_aggregation = delta_aggregation
        self.delta_discrimination = delta_discrimination
        self.weights = (1.0, 1.0)

    def _forward_single(self, emb: torch.Tensor, instance: torch.Tensor,
                        kernel: torch.Tensor,
                        training_mask: torch.Tensor) -> torch.Tensor:
        """Compute the loss for a single image.

        Args:
            emb (torch.Tensor): The embedding feature of shape
                :math:`(feature_dim, H, W)`.
            instance (torch.Tensor): The instance index map, where each
                pixel holds the index of the text instance it belongs to
                (0 for background).
            kernel (torch.Tensor): The kernel index map.
            training_mask (torch.Tensor): The effective mask.
        """
        training_mask = (training_mask > 0.5).float()
        kernel = (kernel > 0.5).float()
        instance = instance * training_mask
        instance_kernel = (instance * kernel).view(-1)
        instance = instance.view(-1)
        emb = emb.view(self.feature_dim, -1)

        unique_labels, unique_ids = torch.unique(
            instance_kernel, sorted=True, return_inverse=True)
        num_instance = unique_labels.size(0)
        if num_instance <= 1:
            return 0

        # Mean embedding of each instance, computed over its kernel pixels.
        emb_mean = emb.new_zeros((self.feature_dim, num_instance),
                                 dtype=torch.float32)
        for i, lb in enumerate(unique_labels):
            if lb == 0:
                continue
            ind_k = instance_kernel == lb
            emb_mean[:, i] = torch.mean(emb[:, ind_k], dim=1)
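        # Aggregation term: pull each pixel embedding towards the mean
        # embedding of its instance; distances within `delta_aggregation`
        # incur no penalty.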
        l_agg = emb.new_zeros(num_instance, dtype=torch.float32)
        for i, lb in enumerate(unique_labels):
            if lb == 0:
                continue
            ind = instance == lb
            emb_ = emb[:, ind]
            dist = (emb_ - emb_mean[:, i:i + 1]).norm(p=2, dim=0)
            dist = F.relu(dist - self.delta_aggregation)**2
            l_agg[i] = torch.mean(torch.log(dist + 1.0))
        l_agg = torch.mean(l_agg[1:])
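        # Discrimination term: push the mean embeddings of different
        # instances apart; pairs farther than 2 * delta_discrimination
        # incur no penalty. The background row/column is masked out.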
        if num_instance > 2:
            emb_interleave = emb_mean.permute(1, 0).repeat(num_instance, 1)
            emb_band = emb_mean.permute(1, 0).repeat(1, num_instance).view(
                -1, self.feature_dim)
            mask = (1 - torch.eye(
                num_instance, dtype=torch.int8,
                device=emb.device)).view(-1, 1).repeat(1, self.feature_dim)
            mask = mask.view(num_instance, num_instance, -1)
            mask[0, :, :] = 0
            mask[:, 0, :] = 0
            mask = mask.view(num_instance * num_instance, -1)

            dist = emb_interleave - emb_band
            dist = dist[mask > 0].view(-1, self.feature_dim).norm(p=2, dim=1)
            dist = F.relu(2 * self.delta_discrimination - dist)**2
            l_dis = torch.mean(torch.log(dist + 1.0))
        else:
            l_dis = 0

        l_agg = self.weights[0] * l_agg
        l_dis = self.weights[1] * l_dis
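        # Small regularizer that discourages the instance mean embeddings
        # from drifting to large magnitudes.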
        l_reg = torch.mean(torch.log(torch.norm(emb_mean, 2, 0) +
                                     1.0)) * 0.001
        loss = l_agg + l_dis + l_reg
        return loss

    def forward(self, emb: torch.Tensor, instance: torch.Tensor,
                kernel: torch.Tensor,
                training_mask: torch.Tensor) -> torch.Tensor:
        """Compute the loss for a batch of images.

        Args:
            emb (torch.Tensor): The embedding features of shape
                :math:`(N, feature_dim, H, W)`.
            instance (torch.Tensor): The instance index maps.
            kernel (torch.Tensor): The kernel index maps.
            training_mask (torch.Tensor): The effective masks.
        """
        loss_batch = emb.new_zeros((emb.size(0)), dtype=torch.float32)
        for i in range(loss_batch.size(0)):
            loss_batch[i] = self._forward_single(emb[i], instance[i],
                                                 kernel[i], training_mask[i])
        return loss_batch