Shortcuts

Note

You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to benefit from the many new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.

Source code for mmocr.models.textdet.losses.pse_loss

# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.core import BitmapMasks

from mmocr.models.builder import LOSSES
from mmocr.utils import check_argument
from . import PANLoss


@LOSSES.register_module()
class PSELoss(PANLoss):
    r"""The class for implementing PSENet loss. This is partially adapted from
    https://github.com/whai362/PSENet.

    PSENet: `Shape Robust Text Detection with Progressive Scale Expansion
    Network <https://arxiv.org/abs/1806.02559>`_.

    Args:
        alpha (float): Text loss coefficient, and :math:`1-\alpha` is the
            kernel loss coefficient.
        ohem_ratio (float): The negative/positive ratio in ohem.
        reduction (str): The way to reduce the loss. Available options are
            "mean" and "sum".
        kernel_sample_type (str): How the pixels that supervise the kernel
            maps are selected. ``'hard'`` keeps pixels where the text target
            (``gt_kernels[0]``) is greater than 0.5; ``'adaptive'`` keeps
            pixels where the predicted text score is greater than 0. Either
            selection is then multiplied by the effective mask. Defaults to
            ``'adaptive'``.
    """

    # Accepted values of ``kernel_sample_type``. Checked in ``__init__`` so a
    # misconfiguration fails when the loss is built, not on the first batch.
    _KERNEL_SAMPLE_TYPES = ('hard', 'adaptive')

    def __init__(self,
                 alpha=0.7,
                 ohem_ratio=3,
                 reduction='mean',
                 kernel_sample_type='adaptive'):
        super().__init__()
        assert reduction in ['mean', 'sum'
                             ], "reduction must be either of ['mean','sum']"
        if kernel_sample_type not in self._KERNEL_SAMPLE_TYPES:
            # Same exception type that ``forward`` raises for this case.
            raise NotImplementedError(
                'kernel_sample_type must be one of {}, got {!r}'.format(
                    self._KERNEL_SAMPLE_TYPES, kernel_sample_type))
        self.alpha = alpha
        self.ohem_ratio = ohem_ratio
        self.reduction = reduction
        self.kernel_sample_type = kernel_sample_type

    def forward(self, score_maps, downsample_ratio, gt_kernels, gt_mask):
        """Compute PSENet loss.

        Args:
            score_maps (tensor): The output tensor with size of
                N x (1 + K) x H x W. Channel 0 is the text score map and
                channels 1..K are the kernel score maps; K must equal the
                number of kernel targets minus one.
            downsample_ratio (float): The downsample ratio between score_maps
                and the input img.
            gt_kernels (list[BitmapMasks]): The kernel list with each element
                being the text kernel mask for one img. After conversion by
                ``bitmasks2tensor``, index 0 is used as the text target and
                indices 1..K as the kernel targets.
            gt_mask (list[BitmapMasks]): The effective mask list with each
                element being the effective mask for one img.

        Returns:
            dict: A loss dict with ``loss_text`` and ``loss_kernel``.
        """
        assert check_argument.is_type_list(gt_kernels, BitmapMasks)
        assert check_argument.is_type_list(gt_mask, BitmapMasks)
        assert isinstance(downsample_ratio, float)

        pred_texts = score_maps[:, 0, :, :]
        pred_kernels = score_maps[:, 1:, :, :]
        feature_hw = score_maps.size()[2:]

        # Bring the ground truth to the score-map resolution and device.
        gt_kernels = self._masks_to_feature_tensors(gt_kernels,
                                                    downsample_ratio,
                                                    feature_hw,
                                                    score_maps.device)
        gt_mask = self._masks_to_feature_tensors(gt_mask, downsample_ratio,
                                                 feature_hw, score_maps.device)
        gt_texts = gt_kernels[0]

        losses = []

        # Text loss: OHEM selects the supervised pixels. The prediction is
        # detached so the sampling step itself carries no gradient.
        sampled_masks_text = self.ohem_batch(pred_texts.detach(), gt_texts,
                                             gt_mask[0])
        loss_texts = self.dice_loss_with_logits(pred_texts, gt_texts,
                                                sampled_masks_text)
        losses.append(self.alpha * loss_texts)

        # Kernel loss: choose which pixels supervise the kernel maps.
        if self.kernel_sample_type == 'hard':
            sampled_masks_kernel = (gt_texts > 0.5).float() * (
                gt_mask[0].float())
        elif self.kernel_sample_type == 'adaptive':
            # pred_texts are presumably logits (they are fed to
            # dice_loss_with_logits), so > 0 selects confident text pixels.
            sampled_masks_kernel = (pred_texts > 0).float() * (
                gt_mask[0].float())
        else:
            raise NotImplementedError(
                'Unsupported kernel_sample_type: {!r}'.format(
                    self.kernel_sample_type))

        num_kernel = pred_kernels.shape[1]
        assert num_kernel == len(gt_kernels) - 1, (
            'score_maps has {} kernel channels but {} kernel targets were '
            'given'.format(num_kernel, len(gt_kernels) - 1))
        loss_list = [
            self.dice_loss_with_logits(pred_kernels[:, idx, :, :],
                                       gt_kernels[1 + idx],
                                       sampled_masks_kernel)
            for idx in range(num_kernel)
        ]
        # Average over kernel levels, then weight by the kernel coefficient.
        losses.append((1 - self.alpha) * sum(loss_list) / len(loss_list))

        if self.reduction == 'mean':
            losses = [item.mean() for item in losses]
        elif self.reduction == 'sum':
            losses = [item.sum() for item in losses]
        else:
            raise NotImplementedError(
                'Unsupported reduction: {!r}'.format(self.reduction))

        results = dict(loss_text=losses[0], loss_kernel=losses[1])
        return results

    def _masks_to_feature_tensors(self, masks, downsample_ratio, out_size,
                                  device):
        """Rescale bitmap masks, convert them with ``bitmasks2tensor`` to
        ``out_size``, and move every resulting tensor to ``device``."""
        masks = [item.rescale(downsample_ratio) for item in masks]
        masks = self.bitmasks2tensor(masks, out_size)
        return [item.to(device) for item in masks]
Read the Docs v: v0.6.3
Versions
latest
stable
v0.6.3
v0.6.2
v0.6.1
v0.6.0
v0.5.0
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
dev-1.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.