Note
You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.
Source code for mmocr.models.textdet.losses.pse_loss
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.core import BitmapMasks
from mmocr.models.builder import LOSSES
from mmocr.utils import check_argument
from . import PANLoss
@LOSSES.register_module()
class PSELoss(PANLoss):
    r"""Loss module for PSENet, partially adapted from
    https://github.com/whai362/PSENet.

    PSENet: `Shape Robust Text Detection with
    Progressive Scale Expansion Network <https://arxiv.org/abs/1806.02559>`_.

    The helpers ``bitmasks2tensor``, ``ohem_batch`` and
    ``dice_loss_with_logits`` used by :meth:`forward` are not defined in this
    class; they are expected to come from the ``PANLoss`` base class.

    Args:
        alpha (float): Text loss coefficient, and :math:`1-\alpha` is the
            kernel loss coefficient.
        ohem_ratio (float): The negative/positive ratio in ohem.
        reduction (str): The way to reduce the loss. Available options are
            "mean" and "sum".
        kernel_sample_type (str): How the pixels that contribute to the
            kernel loss are selected. ``'hard'`` keeps pixels whose
            ground-truth text map is above 0.5, ``'adaptive'`` keeps pixels
            whose predicted text logit is above 0. In both cases the
            selection is multiplied by the effective mask. Any other value
            makes :meth:`forward` raise ``NotImplementedError``.
    """

    def __init__(self,
                 alpha=0.7,
                 ohem_ratio=3,
                 reduction='mean',
                 kernel_sample_type='adaptive'):
        # The base class is initialised with its own defaults first; the
        # attributes below then overwrite whatever it stored.
        super().__init__()
        assert reduction in ('mean', 'sum'), \
            "reduction must be either of ['mean','sum']"
        self.alpha = alpha
        self.ohem_ratio = ohem_ratio
        self.reduction = reduction
        self.kernel_sample_type = kernel_sample_type

    def forward(self, score_maps, downsample_ratio, gt_kernels, gt_mask):
        """Compute PSENet loss.

        Args:
            score_maps (tensor): The output tensor with size of Nx6xHxW.
                Channel 0 is used as the text prediction and the remaining
                channels as the kernel predictions.
            downsample_ratio (float): The downsample ratio between score_maps
                and the input img.
            gt_kernels (list[BitmapMasks]): The kernel list with each element
                being the text kernel mask for one img.
            gt_mask (list[BitmapMasks]): The effective mask list
                with each element being the effective mask for one img.

        Returns:
            dict: A loss dict with ``loss_text`` and ``loss_kernel``.
        """
        assert check_argument.is_type_list(gt_kernels, BitmapMasks)
        assert check_argument.is_type_list(gt_mask, BitmapMasks)
        assert isinstance(downsample_ratio, float)

        text_logits = score_maps[:, 0, :, :]
        kernel_logits = score_maps[:, 1:, :, :]
        map_hw = score_maps.size()[2:]
        device = score_maps.device

        def to_target_tensors(bitmaps):
            # Rescale the masks to the score-map resolution, convert them to
            # tensors and move them next to the predictions.
            scaled = [bitmap.rescale(downsample_ratio) for bitmap in bitmaps]
            tensors = self.bitmasks2tensor(scaled, map_hw)
            return [tensor.to(device) for tensor in tensors]

        kernel_targets = to_target_tensors(gt_kernels)
        mask_targets = to_target_tensors(gt_mask)

        # The first target is the full text map; the first mask is the
        # effective-region mask shared by both loss terms.
        text_target = kernel_targets[0]
        effective_mask = mask_targets[0]

        # Text loss: OHEM sampling is done on detached predictions.
        ohem_mask = self.ohem_batch(text_logits.detach(), text_target,
                                    effective_mask)
        text_loss = self.alpha * self.dice_loss_with_logits(
            text_logits, text_target, ohem_mask)

        # Kernel loss: choose the region in which kernels are supervised.
        if self.kernel_sample_type == 'hard':
            kernel_region = (text_target > 0.5).float()
        elif self.kernel_sample_type == 'adaptive':
            kernel_region = (text_logits > 0).float()
        else:
            raise NotImplementedError
        kernel_sample_mask = kernel_region * effective_mask.float()

        num_kernel = kernel_logits.shape[1]
        assert num_kernel == len(kernel_targets) - 1
        per_kernel_losses = [
            self.dice_loss_with_logits(kernel_logits[:, k, :, :],
                                       kernel_targets[k + 1],
                                       kernel_sample_mask)
            for k in range(num_kernel)
        ]
        # Average over the kernel maps, weighted by (1 - alpha).
        kernel_loss = ((1 - self.alpha) * sum(per_kernel_losses) /
                       len(per_kernel_losses))

        if self.reduction == 'mean':
            text_loss, kernel_loss = text_loss.mean(), kernel_loss.mean()
        elif self.reduction == 'sum':
            text_loss, kernel_loss = text_loss.sum(), kernel_loss.sum()
        else:
            raise NotImplementedError

        return {'loss_text': text_loss, 'loss_kernel': kernel_loss}