Shortcuts

Source code for mmocr.models.textdet.losses.pse_loss

# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.core import BitmapMasks

from mmocr.models.builder import LOSSES
from mmocr.utils import check_argument
from . import PANLoss


@LOSSES.register_module()
class PSELoss(PANLoss):
    r"""The class for implementing PSENet loss. This is partially adapted from
    https://github.com/whai362/PSENet.

    PSENet: `Shape Robust Text Detection with
    Progressive Scale Expansion Network <https://arxiv.org/abs/1806.02559>`_.

    Args:
        alpha (float): Text loss coefficient, and :math:`1-\alpha` is the
            kernel loss coefficient.
        ohem_ratio (float): The negative/positive ratio in ohem.
        reduction (str): The way to reduce the loss. Available options are
            "mean" and "sum".
        kernel_sample_type (str): Strategy for selecting the pixels that
            contribute to the kernel loss. 'hard' uses the ground-truth text
            region; 'adaptive' uses the currently predicted text region
            (pixels where the text logit is positive).
    """

    def __init__(self,
                 alpha=0.7,
                 ohem_ratio=3,
                 reduction='mean',
                 kernel_sample_type='adaptive'):
        super().__init__()
        assert reduction in ['mean', 'sum'
                             ], "reduction must be either of ['mean','sum']"
        self.alpha = alpha
        self.ohem_ratio = ohem_ratio
        self.reduction = reduction
        self.kernel_sample_type = kernel_sample_type

    def forward(self, score_maps, downsample_ratio, gt_kernels, gt_mask):
        """Compute PSENet loss.

        Args:
            score_maps (Tensor): The output tensor with size of
                :math:`(N, 6, H, W)`; channel 0 is the text score map and
                channels 1..5 are the shrunken kernel maps.
            downsample_ratio (float): The downsample ratio between score_maps
                and the input img.
            gt_kernels (list[BitmapMasks]): The kernel list with each element
                being the text kernel mask for one img.
            gt_mask (list[BitmapMasks]): The effective mask list with each
                element being the effective mask for one img.

        Returns:
            dict: A loss dict with ``loss_text`` and ``loss_kernel``.
        """
        assert check_argument.is_type_list(gt_kernels, BitmapMasks)
        assert check_argument.is_type_list(gt_mask, BitmapMasks)
        assert isinstance(downsample_ratio, float)
        losses = []

        pred_texts = score_maps[:, 0, :, :]
        pred_kernels = score_maps[:, 1:, :, :]
        feature_sz = score_maps.size()

        # Rescale the ground truth to the score-map resolution and move it
        # onto the same device as the predictions.
        gt_kernels = [item.rescale(downsample_ratio) for item in gt_kernels]
        gt_kernels = self.bitmasks2tensor(gt_kernels, feature_sz[2:])
        gt_kernels = [item.to(score_maps.device) for item in gt_kernels]

        gt_mask = [item.rescale(downsample_ratio) for item in gt_mask]
        gt_mask = self.bitmasks2tensor(gt_mask, feature_sz[2:])
        gt_mask = [item.to(score_maps.device) for item in gt_mask]

        # Text loss: OHEM-sampled dice loss on the full text map. The
        # prediction is detached for sampling so OHEM does not backprop.
        sampled_masks_text = self.ohem_batch(pred_texts.detach(),
                                             gt_kernels[0], gt_mask[0])
        loss_texts = self.dice_loss_with_logits(pred_texts, gt_kernels[0],
                                                sampled_masks_text)
        losses.append(self.alpha * loss_texts)

        # Kernel loss: restricted to text pixels, chosen either from the
        # ground truth ('hard') or the current prediction ('adaptive').
        if self.kernel_sample_type == 'hard':
            sampled_masks_kernel = (gt_kernels[0] > 0.5).float() * (
                gt_mask[0].float())
        elif self.kernel_sample_type == 'adaptive':
            sampled_masks_kernel = (pred_texts > 0).float() * (
                gt_mask[0].float())
        else:
            raise NotImplementedError

        num_kernel = pred_kernels.shape[1]
        # gt_kernels[0] is the full text mask; the remaining entries must
        # match the predicted kernel channels one-to-one.
        assert num_kernel == len(gt_kernels) - 1
        loss_list = []
        for idx in range(num_kernel):
            loss_kernels = self.dice_loss_with_logits(
                pred_kernels[:, idx, :, :], gt_kernels[1 + idx],
                sampled_masks_kernel)
            loss_list.append(loss_kernels)

        losses.append((1 - self.alpha) * sum(loss_list) / len(loss_list))

        if self.reduction == 'mean':
            losses = [item.mean() for item in losses]
        elif self.reduction == 'sum':
            losses = [item.sum() for item in losses]
        else:
            raise NotImplementedError

        results = dict(loss_text=losses[0], loss_kernel=losses[1])
        return results
Read the Docs v: v0.4.0
Versions
latest
stable
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.