Shortcuts

Note

You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend you upgrade to MMOCR 1.0 to enjoy the fruitful new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code, and documentation of MMOCR 1.0 for more details.

Source code for mmocr.models.textdet.losses.fce_loss

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn.functional as F
from mmdet.core import multi_apply
from torch import nn

from mmocr.models.builder import LOSSES


@LOSSES.register_module()
class FCELoss(nn.Module):
    """Implementation of the FCENet loss.

    FCENet(CVPR2021): `Fourier Contour Embedding for Arbitrary-shaped Text
    Detection <https://arxiv.org/abs/2104.10442>`_

    Args:
        fourier_degree (int): The maximum Fourier transform degree k.
        num_sample (int): The number of sampling points of the regression
            loss. If it is too small, FCENet tends to overfit.
        ohem_ratio (float): The negative/positive ratio in OHEM.
    """

    def __init__(self, fourier_degree, num_sample, ohem_ratio=3.):
        super().__init__()
        self.fourier_degree = fourier_degree
        self.num_sample = num_sample
        self.ohem_ratio = ohem_ratio

    def forward(self, preds, _, p3_maps, p4_maps, p5_maps):
        """Compute FCENet loss over the three pyramid levels.

        Args:
            preds (list[list[Tensor]]): The outer list indicates images in
                a batch, and the inner list indicates the classification
                prediction map (with shape :math:`(N, C, H, W)`) and the
                regression map (with shape :math:`(N, C, H, W)`).
            p3_maps (list[ndarray]): List of level-3 ground truth target
                maps, each with shape :math:`(C, H, W)`.
            p4_maps (list[ndarray]): List of level-4 ground truth target
                maps, each with shape :math:`(C, H, W)`.
            p5_maps (list[ndarray]): List of level-5 ground truth target
                maps, each with shape :math:`(C, H, W)`.

        Returns:
            dict: A loss dict with ``loss_text``, ``loss_center``,
            ``loss_reg_x`` and ``loss_reg_y``.
        """
        assert isinstance(preds, list)
        # The GT channel layout is 1 (tr) + 1 (tcl) + 1 (train mask)
        # + 2 * (2k + 1) regression channels = 4k + 5.
        assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
            'fourier degree not equal in FCEhead and FCEtarget'

        device = preds[0][0].device
        # Stack each level's per-image numpy maps into one float tensor.
        gts = [
            torch.from_numpy(np.stack(level)).float().to(device)
            for level in (p3_maps, p4_maps, p5_maps)
        ]

        # multi_apply yields one sequence of per-level values per loss
        # component, in this fixed order.
        per_component = multi_apply(self.forward_single, preds, gts)
        names = ('loss_text', 'loss_center', 'loss_reg_x', 'loss_reg_y')
        zero = torch.tensor(0., device=device).float()
        return {
            name: zero + sum(component)
            for name, component in zip(names, per_component)
        }

    def forward_single(self, pred, gt):
        """Compute the three loss terms for one pyramid level."""
        cls_map = pred[0].permute(0, 2, 3, 1).contiguous()
        reg_map = pred[1].permute(0, 2, 3, 1).contiguous()
        gt = gt.permute(0, 2, 3, 1).contiguous()

        k = 2 * self.fourier_degree + 1
        # Flatten spatial dims: two 2-class logit maps, then the 2k
        # Fourier-coefficient regression channels.
        tr_pred = cls_map[:, :, :, :2].view(-1, 2)
        tcl_pred = cls_map[:, :, :, 2:].view(-1, 2)
        x_pred = reg_map[:, :, :, 0:k].view(-1, k)
        y_pred = reg_map[:, :, :, k:2 * k].view(-1, k)

        tr_mask = gt[:, :, :, :1].view(-1)
        tcl_mask = gt[:, :, :, 1:2].view(-1)
        train_mask = gt[:, :, :, 2:3].view(-1)
        x_map = gt[:, :, :, 3:3 + k].view(-1, k)
        y_map = gt[:, :, :, 3 + k:].view(-1, k)

        tr_train_mask = train_mask * tr_mask
        device = x_map.device

        # Text-region loss with online hard example mining.
        loss_tr = self.ohem(tr_pred, tr_mask.long(), train_mask.long())

        # Text-center-line loss: negatives contribute with half weight.
        loss_tcl = torch.tensor(0.).float().to(device)
        tr_neg_mask = 1 - tr_train_mask
        if tr_train_mask.sum().item() > 0:
            pos = tr_train_mask.bool()
            neg = tr_neg_mask.bool()
            loss_tcl_pos = F.cross_entropy(tcl_pred[pos],
                                           tcl_mask[pos].long())
            loss_tcl_neg = F.cross_entropy(tcl_pred[neg],
                                           tcl_mask[neg].long())
            loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg

        # Regression loss on contours reconstructed from the Fourier
        # coefficients, weighted per pixel by the mean of the text-region
        # and center-line masks.
        loss_reg_x = torch.tensor(0.).float().to(device)
        loss_reg_y = torch.tensor(0.).float().to(device)
        if tr_train_mask.sum().item() > 0:
            sel = tr_train_mask.bool()
            weight = (tr_mask[sel].float() + tcl_mask[sel].float()) / 2
            weight = weight.contiguous().view(-1, 1)

            ft_x, ft_y = self.fourier2poly(x_map, y_map)
            ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred)

            loss_reg_x = torch.mean(weight * F.smooth_l1_loss(
                ft_x_pre[sel], ft_x[sel], reduction='none'))
            loss_reg_y = torch.mean(weight * F.smooth_l1_loss(
                ft_y_pre[sel], ft_y[sel], reduction='none'))

        return loss_tr, loss_tcl, loss_reg_x, loss_reg_y

    def ohem(self, predict, target, train_mask):
        """Cross-entropy with online hard example mining.

        At most ``ohem_ratio`` negatives per positive are kept, choosing
        the hardest (highest-loss) ones; when no positives exist, a fixed
        budget of 100 negatives is used.
        """
        device = train_mask.device
        pos = (target * train_mask).bool()
        neg = ((1 - target) * train_mask).bool()
        n_pos = pos.float().sum()

        # Per-element negative losses; needed in both branches.
        loss_neg = F.cross_entropy(
            predict[neg], target[neg], reduction='none')
        if n_pos.item() > 0:
            loss_pos = F.cross_entropy(
                predict[pos], target[pos], reduction='sum')
            n_neg = min(
                int(neg.float().sum().item()),
                int(self.ohem_ratio * n_pos.float()))
        else:
            loss_pos = torch.tensor(0.).to(device)
            n_neg = 100
        # Keep only the hardest negatives if over budget.
        if len(loss_neg) > n_neg:
            loss_neg, _ = torch.topk(loss_neg, n_neg)

        return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()

    def fourier2poly(self, real_maps, imag_maps):
        """Transform Fourier coefficient maps to polygon maps.

        Args:
            real_maps (tensor): A map composed of the real parts of the
                Fourier coefficients, whose shape is (-1, 2k+1).
            imag_maps (tensor): A map composed of the imaginary parts of
                the Fourier coefficients, whose shape is (-1, 2k+1).

        Returns:
            tuple(tensor, tensor):
                - x_maps: x values of the polygon represented by n sample
                  points (xn, yn), shape (-1, n).
                - y_maps: y values of the polygon represented by n sample
                  points (xn, yn), shape (-1, n).
        """
        device = real_maps.device
        k_vect = torch.arange(
            -self.fourier_degree,
            self.fourier_degree + 1,
            dtype=torch.float,
            device=device).view(-1, 1)
        i_vect = torch.arange(
            0, self.num_sample, dtype=torch.float, device=device).view(1, -1)

        # Angle grid 2*pi*k*i/n of the inverse discrete Fourier transform.
        transform_matrix = 2 * np.pi / self.num_sample * torch.mm(
            k_vect, i_vect)

        x1 = torch.einsum('ak, kn-> an', real_maps,
                          torch.cos(transform_matrix))
        x2 = torch.einsum('ak, kn-> an', imag_maps,
                          torch.sin(transform_matrix))
        y1 = torch.einsum('ak, kn-> an', real_maps,
                          torch.sin(transform_matrix))
        y2 = torch.einsum('ak, kn-> an', imag_maps,
                          torch.cos(transform_matrix))

        x_maps = x1 - x2
        y_maps = y1 + y2
        return x_maps, y_maps
Read the Docs v: v0.6.3
Versions
latest
stable
v0.6.3
v0.6.2
v0.6.1
v0.6.0
v0.5.0
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
dev-1.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.