Note
You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend that you upgrade to MMOCR 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.
Source code for mmocr.models.textdet.losses.pan_loss
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import warnings
import numpy as np
import torch
import torch.nn.functional as F
from mmdet.core import BitmapMasks
from torch import nn
from mmocr.models.builder import LOSSES
from mmocr.utils import check_argument
@LOSSES.register_module()
class PANLoss(nn.Module):
    """The class for implementing PANet loss. This was partially adapted from
    https://github.com/WenmuZhou/PAN.pytorch.

    PANet: `Efficient and Accurate Arbitrary-
    Shaped Text Detection with Pixel Aggregation Network
    <https://arxiv.org/abs/1908.05900>`_.

    Args:
        alpha (float): The kernel loss coef.
        beta (float): The aggregation and discriminative loss coef.
        delta_aggregation (float): The constant for aggregation loss.
        delta_discrimination (float): The constant for discriminative loss.
        ohem_ratio (float): The negative/positive ratio in ohem.
        reduction (str): The way to reduce the loss. Must be ``'mean'`` or
            ``'sum'``.
        speedup_bbox_thr (int): Speed up if speedup_bbox_thr > 0
            and < bbox num.
    """

    def __init__(self,
                 alpha=0.5,
                 beta=0.25,
                 delta_aggregation=0.5,
                 delta_discrimination=3,
                 ohem_ratio=3,
                 reduction='mean',
                 speedup_bbox_thr=-1):
        super().__init__()
        assert reduction in ['mean', 'sum'], "reduction must in ['mean','sum']"
        self.alpha = alpha
        self.beta = beta
        self.delta_aggregation = delta_aggregation
        self.delta_discrimination = delta_discrimination
        self.ohem_ratio = ohem_ratio
        self.reduction = reduction
        self.speedup_bbox_thr = speedup_bbox_thr

    def bitmasks2tensor(self, bitmasks, target_sz):
        """Convert Bitmasks to tensor.

        Args:
            bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
                for one img.
            target_sz (tuple(int, int)): The target tensor of size
                :math:`(H, W)`.

        Returns:
            list[Tensor]: The list of kernel tensors. Each element stands for
            one kernel level and stacks the padded masks of all imgs.
        """
        assert check_argument.is_type_list(bitmasks, BitmapMasks)
        assert isinstance(target_sz, tuple)

        # The number of levels is taken from the first img; every other img
        # is indexed with the same level indices.
        num_levels = len(bitmasks[0])

        results = []
        for level_inx in range(num_levels):
            kernel = []
            for bitmask in bitmasks:
                mask = torch.from_numpy(bitmask.masks[level_inx])
                # hxw
                mask_sz = mask.shape
                # left, right, top, bottom: only the right and bottom edges
                # are padded, so mask content stays anchored at the top-left.
                pad = [
                    0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
                ]
                kernel.append(F.pad(mask, pad, mode='constant', value=0))
            results.append(torch.stack(kernel))

        return results

    def forward(self, preds, downsample_ratio, gt_kernels, gt_mask):
        """Compute PANet loss.

        Args:
            preds (Tensor): The output tensor of size :math:`(N, 6, H, W)`.
                Channel 0 is read as the text score, channel 1 as the kernel
                score and the remaining 4 channels as the instance embedding.
            downsample_ratio (float): The downsample ratio between preds
                and the input img.
            gt_kernels (list[BitmapMasks]): The kernel list with each element
                being the text kernel mask for one img.
            gt_mask (list[BitmapMasks]): The effective mask list
                with each element being the effective mask for one img.

        Returns:
            dict: A loss dict with ``loss_text``, ``loss_kernel``,
            ``loss_aggregation`` and ``loss_discrimination``.
        """
        assert check_argument.is_type_list(gt_kernels, BitmapMasks)
        assert check_argument.is_type_list(gt_mask, BitmapMasks)
        assert isinstance(downsample_ratio, float)

        pred_texts = preds[:, 0, :, :]
        pred_kernels = preds[:, 1, :, :]
        inst_embed = preds[:, 2:, :, :]
        feature_sz = preds.size()

        # Rescale the ground truth to the prediction resolution, pad it to
        # the feature size and move it to the device of the predictions.
        mapping = {'gt_kernels': gt_kernels, 'gt_mask': gt_mask}
        gt = {}
        for key, value in mapping.items():
            rescaled = [item.rescale(downsample_ratio) for item in value]
            tensors = self.bitmasks2tensor(rescaled, feature_sz[2:])
            gt[key] = [item.to(preds.device) for item in tensors]

        loss_aggrs, loss_discrs = self.aggregation_discrimination_loss(
            gt['gt_kernels'][0], gt['gt_kernels'][1], inst_embed)

        # compute text loss
        sampled_mask = self.ohem_batch(pred_texts.detach(),
                                       gt['gt_kernels'][0], gt['gt_mask'][0])
        loss_texts = self.dice_loss_with_logits(pred_texts,
                                                gt['gt_kernels'][0],
                                                sampled_mask)

        # compute kernel loss
        sampled_masks_kernel = (gt['gt_kernels'][0] > 0.5).float() * (
            gt['gt_mask'][0].float())
        loss_kernels = self.dice_loss_with_logits(pred_kernels,
                                                  gt['gt_kernels'][1],
                                                  sampled_masks_kernel)

        losses = [loss_texts, loss_kernels, loss_aggrs, loss_discrs]
        if self.reduction == 'mean':
            losses = [item.mean() for item in losses]
        elif self.reduction == 'sum':
            losses = [item.sum() for item in losses]
        else:
            raise NotImplementedError

        coefs = [1, self.alpha, self.beta, self.beta]
        losses = [item * scale for item, scale in zip(losses, coefs)]

        return dict(
            loss_text=losses[0],
            loss_kernel=losses[1],
            loss_aggregation=losses[2],
            loss_discrimination=losses[3])

    @staticmethod
    def _squared_hinge(value):
        """Return ``max(value, 0) ** 2`` elementwise."""
        zero = torch.tensor(0, device=value.device, dtype=torch.float)
        return torch.max(value, zero).pow(2)

    def aggregation_discrimination_loss(self, gt_texts, gt_kernels,
                                        inst_embeds):
        """Compute the aggregation and discriminative losses.

        Args:
            gt_texts (Tensor): The ground truth text mask with batch size
                :math:`N` as its first dimension. Pixel values are used as
                text instance indices, with instances numbered from 1.
            gt_kernels (Tensor): The ground truth text kernel mask, using
                the same layout and instance indices as ``gt_texts``.
            inst_embeds (Tensor): The text instance embedding tensor
                of size :math:`(N, 4, H, W)`.

        Returns:
            (Tensor, Tensor): A tuple of aggregation loss and discriminative
            loss before reduction, each holding one value per img.
        """
        batch_size = gt_texts.size()[0]
        gt_texts = gt_texts.contiguous().reshape(batch_size, -1)
        gt_kernels = gt_kernels.contiguous().reshape(batch_size, -1)

        assert inst_embeds.shape[1] == 4
        inst_embeds = inst_embeds.contiguous().reshape(batch_size, 4, -1)

        loss_aggrs = []
        loss_discrs = []

        for text, kernel, embed in zip(gt_texts, gt_kernels, inst_embeds):
            # for each image
            text_num = int(text.max().item())
            loss_aggr_img = []
            kernel_avgs = []
            select_num = self.speedup_bbox_thr
            if 0 < select_num < text_num:
                # Randomly keep only ``select_num`` instances of this img.
                inds = np.random.choice(
                    text_num, select_num, replace=False) + 1
            else:
                inds = range(1, text_num + 1)

            for i in inds:
                # for each text instance
                kernel_i = (kernel == i)
                text_i = (text == i)
                # Skip instances missing from either the kernel or text map.
                if kernel_i.sum() == 0 or text_i.sum() == 0:
                    continue

                # compute G_Ki in Eq (2)
                avg = embed[:, kernel_i].mean(1)
                kernel_avgs.append(avg)

                embed_i = embed[:, text_i]
                # ||F(p) - G(K_i)|| - delta_aggregation, shape: nums
                distance = (embed_i - avg.reshape(4, 1)).norm(
                    2, dim=0) - self.delta_aggregation
                # compute D(p,K_i) in Eq (2)
                hinge = self._squared_hinge(distance)

                loss_aggr_img.append(torch.log(hinge + 1).mean())

            num_inst = len(loss_aggr_img)
            if num_inst > 0:
                loss_aggr_img = torch.stack(loss_aggr_img).mean()
            else:
                loss_aggr_img = torch.tensor(
                    0, device=gt_texts.device, dtype=torch.float)
            loss_aggrs.append(loss_aggr_img)

            loss_discr_img = 0
            for avg_i, avg_j in itertools.combinations(kernel_avgs, 2):
                # delta_discrimination - ||G(K_i) - G(K_j)||
                distance_ij = self.delta_discrimination - (avg_i -
                                                           avg_j).norm(2)
                # D(K_i,K_j)
                D_ij = self._squared_hinge(distance_ij)
                loss_discr_img += torch.log(D_ij + 1)

            if num_inst > 1:
                # NOTE: the sum above runs over unordered pairs, i.e.
                # n * (n - 1) / 2 terms, while the divisor is n * (n - 1).
                # Kept unchanged so that the loss scale stays the same.
                loss_discr_img /= (num_inst * (num_inst - 1))
            else:
                loss_discr_img = torch.tensor(
                    0, device=gt_texts.device, dtype=torch.float)
            if num_inst == 0:
                warnings.warn('num of instance is 0')
            loss_discrs.append(loss_discr_img)

        return torch.stack(loss_aggrs), torch.stack(loss_discrs)

    def dice_loss_with_logits(self, pred, target, mask):
        """Compute the masked dice loss between logits and a target map.

        Args:
            pred (Tensor): The predicted logits with batch size :math:`N` as
                the first dimension. Sigmoid is applied inside this method.
            target (Tensor): The target map of the same layout as ``pred``.
                Values above 0.5 count as positive, all others as negative.
                The tensor is not modified.
            mask (Tensor): The mask multiplied onto both the prediction and
                the target before the loss is computed.

        Returns:
            Tensor: The dice loss of size :math:`(N,)`, one value per sample.
        """
        smooth = 0.001

        pred = torch.sigmoid(pred)
        # Binarize into a new tensor. Writing the result back into ``target``
        # would overwrite the caller's ground truth map, which may still hold
        # instance indices needed elsewhere.
        target = (target > 0.5).to(pred.dtype)

        pred = pred.contiguous().view(pred.size()[0], -1)
        target = target.contiguous().view(target.size()[0], -1)
        mask = mask.contiguous().view(mask.size()[0], -1)

        pred = pred * mask
        target = target * mask

        a = torch.sum(pred * target, 1) + smooth
        b = torch.sum(pred * pred, 1) + smooth
        c = torch.sum(target * target, 1) + smooth
        d = (2 * a) / (b + c)
        return 1 - d

    def ohem_img(self, text_score, gt_text, gt_mask):
        """Sample the top-k maximal negative samples and all positive samples.

        Args:
            text_score (Tensor): The text score of size :math:`(H, W)`.
            gt_text (Tensor): The ground truth text mask of size
                :math:`(H, W)`.
            gt_mask (Tensor): The effective region mask of size :math:`(H, W)`.

        Returns:
            Tensor: The sampled pixel mask of size :math:`(H, W)`.
        """
        assert isinstance(text_score, torch.Tensor)
        assert isinstance(gt_text, torch.Tensor)
        assert isinstance(gt_mask, torch.Tensor)
        assert len(text_score.shape) == 2
        assert text_score.shape == gt_text.shape
        assert gt_text.shape == gt_mask.shape

        positive = gt_text > 0.5
        negative = gt_text <= 0.5

        # Positives are counted inside the effective region only.
        pos_num = int(torch.sum(positive).item()) - int(
            torch.sum(positive & (gt_mask <= 0.5)).item())
        neg_num = int(torch.sum(negative).item())
        neg_num = int(min(pos_num * self.ohem_ratio, neg_num))

        if pos_num == 0 or neg_num == 0:
            warnings.warn('pos_num = 0 or neg_num = 0')
            return gt_mask.bool()

        # The ``neg_num``-th highest negative score is the threshold.
        neg_score = text_score[negative]
        neg_score_sorted, _ = torch.sort(neg_score, descending=True)
        threshold = neg_score_sorted[neg_num - 1]

        # Keep hard negatives and all positives, limited to the effective
        # region.
        sampled_mask = ((text_score >= threshold) | positive) & (
            gt_mask > 0.5)
        return sampled_mask

    def ohem_batch(self, text_scores, gt_texts, gt_mask):
        """OHEM sampling for a batch of imgs.

        Args:
            text_scores (Tensor): The text scores of size :math:`(N, H, W)`.
            gt_texts (Tensor): The gt text masks of size :math:`(N, H, W)`.
            gt_mask (Tensor): The gt effective mask of size
                :math:`(N, H, W)`.

        Returns:
            Tensor: The sampled mask of size :math:`(N, H, W)`.
        """
        assert isinstance(text_scores, torch.Tensor)
        assert isinstance(gt_texts, torch.Tensor)
        assert isinstance(gt_mask, torch.Tensor)
        assert len(text_scores.shape) == 3
        assert text_scores.shape == gt_texts.shape
        assert gt_texts.shape == gt_mask.shape

        sampled_masks = [
            self.ohem_img(text_scores[i], gt_texts[i], gt_mask[i])
            for i in range(text_scores.shape[0])
        ]
        return torch.stack(sampled_masks)