Shortcuts

Source code for mmocr.models.textdet.losses.pan_loss

# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from mmdet.core import BitmapMasks
from torch import nn

from mmocr.models.builder import LOSSES
from mmocr.utils import check_argument


@LOSSES.register_module()
class PANLoss(nn.Module):
    """The class for implementing PANet loss. This was partially adapted from
    https://github.com/WenmuZhou/PAN.pytorch.

    PANet: `Efficient and Accurate Arbitrary-Shaped Text Detection with
    Pixel Aggregation Network <https://arxiv.org/abs/1908.05900>`_.

    Args:
        alpha (float): The kernel loss coef.
        beta (float): The aggregation and discriminative loss coef.
        delta_aggregation (float): The constant for aggregation loss.
        delta_discrimination (float): The constant for discriminative loss.
        ohem_ratio (float): The negative/positive ratio in ohem.
        reduction (str): The way to reduce the loss.
        speedup_bbox_thr (int): Speed up if speedup_bbox_thr > 0
            and < bbox num.
    """

    def __init__(self,
                 alpha=0.5,
                 beta=0.25,
                 delta_aggregation=0.5,
                 delta_discrimination=3,
                 ohem_ratio=3,
                 reduction='mean',
                 speedup_bbox_thr=-1):
        super().__init__()
        # forward() only implements these two reduction modes.
        assert reduction in ['mean', 'sum'], "reduction must in ['mean','sum']"
        self.alpha = alpha
        self.beta = beta
        self.delta_aggregation = delta_aggregation
        self.delta_discrimination = delta_discrimination
        self.ohem_ratio = ohem_ratio
        self.reduction = reduction
        self.speedup_bbox_thr = speedup_bbox_thr
[docs] def bitmasks2tensor(self, bitmasks, target_sz): """Convert Bitmasks to tensor. Args: bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is for one img. target_sz (tuple(int, int)): The target tensor of size :math:`(H, W)`. Returns: list[Tensor]: The list of kernel tensors. Each element stands for one kernel level. """ assert check_argument.is_type_list(bitmasks, BitmapMasks) assert isinstance(target_sz, tuple) batch_size = len(bitmasks) num_masks = len(bitmasks[0]) results = [] for level_inx in range(num_masks): kernel = [] for batch_inx in range(batch_size): mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx]) # hxw mask_sz = mask.shape # left, right, top, bottom pad = [ 0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0] ] mask = F.pad(mask, pad, mode='constant', value=0) kernel.append(mask) kernel = torch.stack(kernel) results.append(kernel) return results
[docs] def forward(self, preds, downsample_ratio, gt_kernels, gt_mask): """Compute PANet loss. Args: preds (Tensor): The output tensor of size :math:`(N, 6, H, W)`. downsample_ratio (float): The downsample ratio between preds and the input img. gt_kernels (list[BitmapMasks]): The kernel list with each element being the text kernel mask for one img. gt_mask (list[BitmapMasks]): The effective mask list with each element being the effective mask for one img. Returns: dict: A loss dict with ``loss_text``, ``loss_kernel``, ``loss_aggregation`` and ``loss_discrimination``. """ assert check_argument.is_type_list(gt_kernels, BitmapMasks) assert check_argument.is_type_list(gt_mask, BitmapMasks) assert isinstance(downsample_ratio, float) pred_texts = preds[:, 0, :, :] pred_kernels = preds[:, 1, :, :] inst_embed = preds[:, 2:, :, :] feature_sz = preds.size() mapping = {'gt_kernels': gt_kernels, 'gt_mask': gt_mask} gt = {} for key, value in mapping.items(): gt[key] = value gt[key] = [item.rescale(downsample_ratio) for item in gt[key]] gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:]) gt[key] = [item.to(preds.device) for item in gt[key]] loss_aggrs, loss_discrs = self.aggregation_discrimination_loss( gt['gt_kernels'][0], gt['gt_kernels'][1], inst_embed) # compute text loss sampled_mask = self.ohem_batch(pred_texts.detach(), gt['gt_kernels'][0], gt['gt_mask'][0]) loss_texts = self.dice_loss_with_logits(pred_texts, gt['gt_kernels'][0], sampled_mask) # compute kernel loss sampled_masks_kernel = (gt['gt_kernels'][0] > 0.5).float() * ( gt['gt_mask'][0].float()) loss_kernels = self.dice_loss_with_logits(pred_kernels, gt['gt_kernels'][1], sampled_masks_kernel) losses = [loss_texts, loss_kernels, loss_aggrs, loss_discrs] if self.reduction == 'mean': losses = [item.mean() for item in losses] elif self.reduction == 'sum': losses = [item.sum() for item in losses] else: raise NotImplementedError coefs = [1, self.alpha, self.beta, self.beta] losses = [item * scale for item, scale in 
zip(losses, coefs)] results = dict() results.update( loss_text=losses[0], loss_kernel=losses[1], loss_aggregation=losses[2], loss_discrimination=losses[3]) return results
[docs] def aggregation_discrimination_loss(self, gt_texts, gt_kernels, inst_embeds): """Compute the aggregation and discrimnative losses. Args: gt_texts (Tensor): The ground truth text mask of size :math:`(N, 1, H, W)`. gt_kernels (Tensor): The ground truth text kernel mask of size :math:`(N, 1, H, W)`. inst_embeds(Tensor): The text instance embedding tensor of size :math:`(N, 1, H, W)`. Returns: (Tensor, Tensor): A tuple of aggregation loss and discriminative loss before reduction. """ batch_size = gt_texts.size()[0] gt_texts = gt_texts.contiguous().reshape(batch_size, -1) gt_kernels = gt_kernels.contiguous().reshape(batch_size, -1) assert inst_embeds.shape[1] == 4 inst_embeds = inst_embeds.contiguous().reshape(batch_size, 4, -1) loss_aggrs = [] loss_discrs = [] for text, kernel, embed in zip(gt_texts, gt_kernels, inst_embeds): # for each image text_num = int(text.max().item()) loss_aggr_img = [] kernel_avgs = [] select_num = self.speedup_bbox_thr if 0 < select_num < text_num: inds = np.random.choice( text_num, select_num, replace=False) + 1 else: inds = range(1, text_num + 1) for i in inds: # for each text instance kernel_i = (kernel == i) # 0.2ms if kernel_i.sum() == 0 or (text == i).sum() == 0: # 0.2ms continue # compute G_Ki in Eq (2) avg = embed[:, kernel_i].mean(1) # 0.5ms kernel_avgs.append(avg) embed_i = embed[:, text == i] # 0.6ms # ||F(p) - G(K_i)|| - delta_aggregation, shape: nums distance = (embed_i - avg.reshape(4, 1)).norm( # 0.5ms 2, dim=0) - self.delta_aggregation # compute D(p,K_i) in Eq (2) hinge = torch.max( distance, torch.tensor(0, device=distance.device, dtype=torch.float)).pow(2) aggr = torch.log(hinge + 1).mean() loss_aggr_img.append(aggr) num_inst = len(loss_aggr_img) if num_inst > 0: loss_aggr_img = torch.stack(loss_aggr_img).mean() else: loss_aggr_img = torch.tensor( 0, device=gt_texts.device, dtype=torch.float) loss_aggrs.append(loss_aggr_img) loss_discr_img = 0 for avg_i, avg_j in itertools.combinations(kernel_avgs, 2): # 
delta_discrimination - ||G(K_i) - G(K_j)|| distance_ij = self.delta_discrimination - (avg_i - avg_j).norm(2) # D(K_i,K_j) D_ij = torch.max( distance_ij, torch.tensor( 0, device=distance_ij.device, dtype=torch.float)).pow(2) loss_discr_img += torch.log(D_ij + 1) if num_inst > 1: loss_discr_img /= (num_inst * (num_inst - 1)) else: loss_discr_img = torch.tensor( 0, device=gt_texts.device, dtype=torch.float) if num_inst == 0: warnings.warn('num of instance is 0') loss_discrs.append(loss_discr_img) return torch.stack(loss_aggrs), torch.stack(loss_discrs)
def dice_loss_with_logits(self, pred, target, mask): smooth = 0.001 pred = torch.sigmoid(pred) target[target <= 0.5] = 0 target[target > 0.5] = 1 pred = pred.contiguous().view(pred.size()[0], -1) target = target.contiguous().view(target.size()[0], -1) mask = mask.contiguous().view(mask.size()[0], -1) pred = pred * mask target = target * mask a = torch.sum(pred * target, 1) + smooth b = torch.sum(pred * pred, 1) + smooth c = torch.sum(target * target, 1) + smooth d = (2 * a) / (b + c) return 1 - d
[docs] def ohem_img(self, text_score, gt_text, gt_mask): """Sample the top-k maximal negative samples and all positive samples. Args: text_score (Tensor): The text score of size :math:`(H, W)`. gt_text (Tensor): The ground truth text mask of size :math:`(H, W)`. gt_mask (Tensor): The effective region mask of size :math:`(H, W)`. Returns: Tensor: The sampled pixel mask of size :math:`(H, W)`. """ assert isinstance(text_score, torch.Tensor) assert isinstance(gt_text, torch.Tensor) assert isinstance(gt_mask, torch.Tensor) assert len(text_score.shape) == 2 assert text_score.shape == gt_text.shape assert gt_text.shape == gt_mask.shape pos_num = (int)(torch.sum(gt_text > 0.5).item()) - (int)( torch.sum((gt_text > 0.5) * (gt_mask <= 0.5)).item()) neg_num = (int)(torch.sum(gt_text <= 0.5).item()) neg_num = (int)(min(pos_num * self.ohem_ratio, neg_num)) if pos_num == 0 or neg_num == 0: warnings.warn('pos_num = 0 or neg_num = 0') return gt_mask.bool() neg_score = text_score[gt_text <= 0.5] neg_score_sorted, _ = torch.sort(neg_score, descending=True) threshold = neg_score_sorted[neg_num - 1] sampled_mask = (((text_score >= threshold) + (gt_text > 0.5)) > 0) * ( gt_mask > 0.5) return sampled_mask
[docs] def ohem_batch(self, text_scores, gt_texts, gt_mask): """OHEM sampling for a batch of imgs. Args: text_scores (Tensor): The text scores of size :math:`(H, W)`. gt_texts (Tensor): The gt text masks of size :math:`(H, W)`. gt_mask (Tensor): The gt effective mask of size :math:`(H, W)`. Returns: Tensor: The sampled mask of size :math:`(H, W)`. """ assert isinstance(text_scores, torch.Tensor) assert isinstance(gt_texts, torch.Tensor) assert isinstance(gt_mask, torch.Tensor) assert len(text_scores.shape) == 3 assert text_scores.shape == gt_texts.shape assert gt_texts.shape == gt_mask.shape sampled_masks = [] for i in range(text_scores.shape[0]): sampled_masks.append( self.ohem_img(text_scores[i], gt_texts[i], gt_mask[i])) sampled_masks = torch.stack(sampled_masks) return sampled_masks
Read the Docs v: v0.4.0
Versions
latest
stable
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.