Shortcuts

Note

You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.

Source code for mmocr.models.textdet.losses.pan_loss

# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from mmdet.core import BitmapMasks
from torch import nn

from mmocr.models.builder import LOSSES
from mmocr.utils import check_argument


@LOSSES.register_module()
class PANLoss(nn.Module):
    """The class for implementing PANet loss.

    This was partially adapted from https://github.com/WenmuZhou/PAN.pytorch.

    PANet: `Efficient and Accurate Arbitrary-Shaped Text Detection with
    Pixel Aggregation Network <https://arxiv.org/abs/1908.05900>`_.

    Args:
        alpha (float): The kernel loss coef.
        beta (float): The aggregation and discriminative loss coef.
        delta_aggregation (float): The constant for aggregation loss.
        delta_discrimination (float): The constant for discriminative loss.
        ohem_ratio (float): The negative/positive ratio in ohem.
        reduction (str): The way to reduce the loss. Must be ``'mean'`` or
            ``'sum'``.
        speedup_bbox_thr (int): Speed up if speedup_bbox_thr > 0
            and < bbox num.
    """

    def __init__(self,
                 alpha=0.5,
                 beta=0.25,
                 delta_aggregation=0.5,
                 delta_discrimination=3,
                 ohem_ratio=3,
                 reduction='mean',
                 speedup_bbox_thr=-1):
        super().__init__()
        assert reduction in ['mean', 'sum'], "reduction must in ['mean','sum']"
        self.alpha = alpha
        self.beta = beta
        self.delta_aggregation = delta_aggregation
        self.delta_discrimination = delta_discrimination
        self.ohem_ratio = ohem_ratio
        self.reduction = reduction
        self.speedup_bbox_thr = speedup_bbox_thr

    def bitmasks2tensor(self, bitmasks, target_sz):
        """Convert Bitmasks to tensor.

        Every mask is zero-padded on its right and bottom edges up to
        ``target_sz`` and the masks of the same level are stacked over the
        batch.

        Args:
            bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
                for one img.
            target_sz (tuple(int, int)): The target tensor of size
                :math:`(H, W)`.

        Returns:
            list[Tensor]: The list of kernel tensors. Each element stands for
            one kernel level and has size :math:`(N, H, W)`.
        """
        assert check_argument.is_type_list(bitmasks, BitmapMasks)
        assert isinstance(target_sz, tuple)

        batch_size = len(bitmasks)
        # The number of levels is read from the first image only.
        num_masks = len(bitmasks[0])

        results = []
        for level_inx in range(num_masks):
            kernel = []
            for batch_inx in range(batch_size):
                mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
                # hxw
                mask_sz = mask.shape
                # left, right, top, bottom
                pad = [
                    0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
                ]
                mask = F.pad(mask, pad, mode='constant', value=0)
                kernel.append(mask)
            results.append(torch.stack(kernel))

        return results

    def forward(self, preds, downsample_ratio, gt_kernels, gt_mask):
        """Compute PANet loss.

        Args:
            preds (Tensor): The output tensor of size :math:`(N, 6, H, W)`.
                Channel 0 is the text logit map, channel 1 the kernel logit
                map and the remaining channels the instance embedding.
            downsample_ratio (float): The downsample ratio between preds
                and the input img.
            gt_kernels (list[BitmapMasks]): The kernel list with each element
                being the text kernel mask for one img.
            gt_mask (list[BitmapMasks]): The effective mask list
                with each element being the effective mask for one img.

        Returns:
            dict: A loss dict with ``loss_text``, ``loss_kernel``,
            ``loss_aggregation`` and ``loss_discrimination``.
        """
        assert check_argument.is_type_list(gt_kernels, BitmapMasks)
        assert check_argument.is_type_list(gt_mask, BitmapMasks)
        assert isinstance(downsample_ratio, float)

        pred_texts = preds[:, 0, :, :]
        pred_kernels = preds[:, 1, :, :]
        inst_embed = preds[:, 2:, :, :]
        feature_sz = preds.size()

        # Rescale the ground truth to the prediction resolution, pad it to
        # the prediction size and move it to the prediction device.
        mapping = {'gt_kernels': gt_kernels, 'gt_mask': gt_mask}
        gt = {}
        for key, value in mapping.items():
            rescaled = [item.rescale(downsample_ratio) for item in value]
            tensors = self.bitmasks2tensor(rescaled, feature_sz[2:])
            gt[key] = [item.to(preds.device) for item in tensors]

        # Level 0 of ``gt_kernels`` is the full text map, level 1 the
        # shrunk kernel map. Both are indexed by instance id in the
        # aggregation/discrimination loss below.
        text_labels = gt['gt_kernels'][0]
        kernel_labels = gt['gt_kernels'][1]
        effective_mask = gt['gt_mask'][0]

        loss_aggrs, loss_discrs = self.aggregation_discrimination_loss(
            text_labels, kernel_labels, inst_embed)

        # compute text loss
        sampled_mask = self.ohem_batch(pred_texts.detach(), text_labels,
                                       effective_mask)
        loss_texts = self.dice_loss_with_logits(pred_texts, text_labels,
                                                sampled_mask)

        # compute kernel loss: only pixels inside effective text regions
        # contribute.
        sampled_masks_kernel = (text_labels > 0.5).float() * (
            effective_mask.float())
        loss_kernels = self.dice_loss_with_logits(pred_kernels, kernel_labels,
                                                  sampled_masks_kernel)

        losses = [loss_texts, loss_kernels, loss_aggrs, loss_discrs]
        if self.reduction == 'mean':
            losses = [item.mean() for item in losses]
        elif self.reduction == 'sum':
            losses = [item.sum() for item in losses]
        else:
            raise NotImplementedError

        coefs = [1, self.alpha, self.beta, self.beta]
        losses = [item * scale for item, scale in zip(losses, coefs)]

        results = dict()
        results.update(
            loss_text=losses[0],
            loss_kernel=losses[1],
            loss_aggregation=losses[2],
            loss_discrimination=losses[3])
        return results

    def aggregation_discrimination_loss(self, gt_texts, gt_kernels,
                                        inst_embeds):
        """Compute the aggregation and discriminative losses.

        Args:
            gt_texts (Tensor): The ground truth text instance label map.
                The first dimension is the batch; the remaining dimensions
                are flattened per image. Pixels of the i-th instance hold
                the value ``i`` (``i >= 1``).
            gt_kernels (Tensor): The ground truth text kernel instance label
                map with the same layout as ``gt_texts``.
            inst_embeds (Tensor): The text instance embedding tensor of size
                :math:`(N, 4, H, W)`.

        Returns:
            (Tensor, Tensor): A tuple of aggregation loss and discriminative
            loss before reduction, each of size :math:`(N,)`.
        """
        batch_size = gt_texts.size()[0]
        gt_texts = gt_texts.contiguous().reshape(batch_size, -1)
        gt_kernels = gt_kernels.contiguous().reshape(batch_size, -1)

        assert inst_embeds.shape[1] == 4
        inst_embeds = inst_embeds.contiguous().reshape(batch_size, 4, -1)

        loss_aggrs = []
        loss_discrs = []

        for text, kernel, embed in zip(gt_texts, gt_kernels, inst_embeds):
            # for each image
            text_num = int(text.max().item())
            loss_aggr_img = []
            kernel_avgs = []
            select_num = self.speedup_bbox_thr
            if 0 < select_num < text_num:
                # Randomly keep ``select_num`` instance ids (1-based).
                inds = np.random.choice(
                    text_num, select_num, replace=False) + 1
            else:
                inds = range(1, text_num + 1)

            for i in inds:
                # for each text instance
                kernel_i = (kernel == i)
                text_i = (text == i)
                if kernel_i.sum() == 0 or text_i.sum() == 0:
                    continue

                # compute G_Ki in Eq (2)
                avg = embed[:, kernel_i].mean(1)
                kernel_avgs.append(avg)

                embed_i = embed[:, text_i]
                # ||F(p) - G(K_i)|| - delta_aggregation, shape: nums
                distance = (embed_i - avg.reshape(4, 1)).norm(
                    2, dim=0) - self.delta_aggregation
                # compute D(p,K_i) in Eq (2)
                hinge = torch.clamp(distance, min=0).pow(2)

                aggr = torch.log(hinge + 1).mean()
                loss_aggr_img.append(aggr)

            num_inst = len(loss_aggr_img)
            if num_inst > 0:
                loss_aggr_img = torch.stack(loss_aggr_img).mean()
            else:
                loss_aggr_img = torch.tensor(
                    0, device=gt_texts.device, dtype=torch.float)
            loss_aggrs.append(loss_aggr_img)

            # The sum runs over unordered instance pairs and is then divided
            # by num_inst * (num_inst - 1).
            loss_discr_img = 0
            for avg_i, avg_j in itertools.combinations(kernel_avgs, 2):
                # delta_discrimination - ||G(K_i) - G(K_j)||
                distance_ij = self.delta_discrimination - (avg_i -
                                                           avg_j).norm(2)
                # D(K_i,K_j)
                D_ij = torch.clamp(distance_ij, min=0).pow(2)
                loss_discr_img = loss_discr_img + torch.log(D_ij + 1)

            if num_inst > 1:
                loss_discr_img = loss_discr_img / (num_inst * (num_inst - 1))
            else:
                loss_discr_img = torch.tensor(
                    0, device=gt_texts.device, dtype=torch.float)
            if num_inst == 0:
                warnings.warn('num of instance is 0')
            loss_discrs.append(loss_discr_img)

        return torch.stack(loss_aggrs), torch.stack(loss_discrs)

    def dice_loss_with_logits(self, pred, target, mask):
        """Compute the per-image dice loss between logits and a target map.

        ``pred`` is passed through a sigmoid, ``target`` is binarized at 0.5
        and both are multiplied by ``mask`` before the dice coefficient is
        computed.

        Args:
            pred (Tensor): The predicted logits. The first dimension is the
                batch; the remaining dimensions are flattened per image.
            target (Tensor): The ground truth map with the same layout as
                ``pred``. Values above 0.5 count as positive. The tensor is
                not modified.
            mask (Tensor): The pixel weights with the same layout as
                ``pred``.

        Returns:
            Tensor: The dice loss of each image, of size :math:`(N,)`.
        """
        smooth = 0.001

        pred = torch.sigmoid(pred)
        # Binarize into a new tensor: ``target`` may be an instance label
        # map that the caller still needs, so it must not be overwritten.
        target = (target > 0.5).to(target.dtype)

        pred = pred.contiguous().view(pred.size()[0], -1)
        target = target.contiguous().view(target.size()[0], -1)
        mask = mask.contiguous().view(mask.size()[0], -1)

        pred = pred * mask
        target = target * mask

        a = torch.sum(pred * target, 1) + smooth
        b = torch.sum(pred * pred, 1) + smooth
        c = torch.sum(target * target, 1) + smooth
        d = (2 * a) / (b + c)
        return 1 - d

    def ohem_img(self, text_score, gt_text, gt_mask):
        """Sample the top-k maximal negative samples and all positive
        samples.

        Args:
            text_score (Tensor): The text score of size :math:`(H, W)`.
            gt_text (Tensor): The ground truth text mask of size
                :math:`(H, W)`.
            gt_mask (Tensor): The effective region mask of size
                :math:`(H, W)`.

        Returns:
            Tensor: The sampled pixel mask of size :math:`(H, W)`. When there
            is no positive or no negative pixel to sample, ``gt_mask`` cast
            to bool is returned.
        """
        assert isinstance(text_score, torch.Tensor)
        assert isinstance(gt_text, torch.Tensor)
        assert isinstance(gt_mask, torch.Tensor)
        assert len(text_score.shape) == 2
        assert text_score.shape == gt_text.shape
        assert gt_text.shape == gt_mask.shape

        is_pos = gt_text > 0.5
        is_neg = gt_text <= 0.5
        is_effective = gt_mask > 0.5

        # Positives are counted inside the effective region only, while
        # negatives are counted over the whole map.
        pos_num = int(torch.sum(is_pos & is_effective).item())
        neg_num = int(torch.sum(is_neg).item())
        neg_num = int(min(pos_num * self.ohem_ratio, neg_num))

        if pos_num == 0 or neg_num == 0:
            warnings.warn('pos_num = 0 or neg_num = 0')
            return gt_mask.bool()

        # The neg_num-th largest negative score is the sampling threshold.
        neg_score = text_score[is_neg]
        neg_score_sorted, _ = torch.sort(neg_score, descending=True)
        threshold = neg_score_sorted[neg_num - 1]

        sampled_mask = ((text_score >= threshold) | is_pos) & is_effective
        return sampled_mask

    def ohem_batch(self, text_scores, gt_texts, gt_mask):
        """OHEM sampling for a batch of imgs.

        Args:
            text_scores (Tensor): The text scores of size :math:`(N, H, W)`.
            gt_texts (Tensor): The gt text masks of size :math:`(N, H, W)`.
            gt_mask (Tensor): The gt effective mask of size
                :math:`(N, H, W)`.

        Returns:
            Tensor: The sampled mask of size :math:`(N, H, W)`.
        """
        assert isinstance(text_scores, torch.Tensor)
        assert isinstance(gt_texts, torch.Tensor)
        assert isinstance(gt_mask, torch.Tensor)
        assert len(text_scores.shape) == 3
        assert text_scores.shape == gt_texts.shape
        assert gt_texts.shape == gt_mask.shape

        sampled_masks = [
            self.ohem_img(text_scores[i], gt_texts[i], gt_mask[i])
            for i in range(text_scores.shape[0])
        ]
        return torch.stack(sampled_masks)
Read the Docs v: v0.6.3
Versions
latest
stable
v0.6.3
v0.6.2
v0.6.1
v0.6.0
v0.5.0
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
dev-1.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.