Shortcuts

Source code for mmocr.models.textdet.losses.fce_loss

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn.functional as F
from mmdet.core import multi_apply
from torch import nn

from mmocr.models.builder import LOSSES


@LOSSES.register_module()
class FCELoss(nn.Module):
    """The class for implementing FCENet loss.

    FCENet(CVPR2021): `Fourier Contour Embedding for Arbitrary-shaped Text
    Detection <https://arxiv.org/abs/2104.10442>`_

    Args:
        fourier_degree (int): The maximum Fourier transform degree k.
        num_sample (int): The sampling points number of regression
            loss. If it is too small, fcenet tends to be overfitting.
        ohem_ratio (float): The negative/positive ratio in OHEM.
    """

    def __init__(self, fourier_degree, num_sample, ohem_ratio=3.):
        super().__init__()
        self.fourier_degree = fourier_degree
        self.num_sample = num_sample
        self.ohem_ratio = ohem_ratio

    def forward(self, preds, _, p3_maps, p4_maps, p5_maps):
        """Compute FCENet loss.

        Args:
            preds (list[list[Tensor]]): One entry per ground-truth level
                (entries are paired, in order, with ``p3_maps``,
                ``p4_maps`` and ``p5_maps``). Each entry holds the
                classification prediction map (with shape
                :math:`(N, C, H, W)`) and the regression prediction map
                (with shape :math:`(N, C, H, W)`).
            _ (Any): Unused; kept so the positional interface is stable.
            p3_maps (list[ndarray]): List of level 3 ground truth target
                maps with shape :math:`(C, H, W)`.
            p4_maps (list[ndarray]): List of level 4 ground truth target
                maps with shape :math:`(C, H, W)`.
            p5_maps (list[ndarray]): List of level 5 ground truth target
                maps with shape :math:`(C, H, W)`.

        Returns:
            dict: A loss dict with ``loss_text``, ``loss_center``,
            ``loss_reg_x`` and ``loss_reg_y``.
        """
        assert isinstance(preds, list)
        # Target channels: 3 masks + 2 * (2 * fourier_degree + 1) coefficients.
        assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
            'fourier degree not equal in FCEhead and FCEtarget'

        device = preds[0][0].device
        # Stack the per-image numpy targets of each level into one float
        # tensor living on the same device as the predictions.
        gts = [
            torch.from_numpy(np.stack(maps)).float().to(device)
            for maps in (p3_maps, p4_maps, p5_maps)
        ]

        # ``losses`` groups the outputs of ``forward_single`` by loss type:
        # one sequence of per-level values for each of the four losses.
        losses = multi_apply(self.forward_single, preds, gts)
        loss_tr, loss_tcl, loss_reg_x, loss_reg_y = (
            sum(per_level) for per_level in losses)

        return dict(
            loss_text=loss_tr,
            loss_center=loss_tcl,
            loss_reg_x=loss_reg_x,
            loss_reg_y=loss_reg_y,
        )

    def forward_single(self, pred, gt):
        """Compute the four FCENet loss terms for a single level.

        Args:
            pred (list[Tensor]): ``pred[0]`` is the classification map
                whose first two channels are the text-region logits and
                the remaining two the text-center-line logits; ``pred[1]``
                is the regression map holding ``2k+1`` real Fourier
                coefficients followed by ``2k+1`` imaginary ones. Both are
                laid out as :math:`(N, C, H, W)`.
            gt (Tensor): Target map laid out as :math:`(N, C, H, W)` whose
                channels are, in order: text-region mask, text-center-line
                mask, training mask, ``2k+1`` real coefficients and
                ``2k+1`` imaginary coefficients.

        Returns:
            tuple[Tensor]: ``(loss_tr, loss_tcl, loss_reg_x, loss_reg_y)``.
        """
        # Channels-last so that every pixel becomes one row after flattening.
        cls_pred = pred[0].permute(0, 2, 3, 1).contiguous()
        reg_pred = pred[1].permute(0, 2, 3, 1).contiguous()
        gt = gt.permute(0, 2, 3, 1).contiguous()

        k = 2 * self.fourier_degree + 1
        tr_pred = cls_pred[:, :, :, :2].view(-1, 2)
        tcl_pred = cls_pred[:, :, :, 2:].view(-1, 2)
        x_pred = reg_pred[:, :, :, 0:k].view(-1, k)
        y_pred = reg_pred[:, :, :, k:2 * k].view(-1, k)

        tr_mask = gt[:, :, :, :1].view(-1)
        tcl_mask = gt[:, :, :, 1:2].view(-1)
        train_mask = gt[:, :, :, 2:3].view(-1)
        x_map = gt[:, :, :, 3:3 + k].view(-1, k)
        y_map = gt[:, :, :, 3 + k:].view(-1, k)

        device = gt.device
        tr_train_mask = train_mask * tr_mask
        # Boolean selectors are built once and reused by every term below.
        pos_sel = tr_train_mask.bool()
        neg_sel = (1 - tr_train_mask).bool()
        has_pos = tr_train_mask.sum().item() > 0

        # Text-region loss with online hard example mining.
        loss_tr = self.ohem(tr_pred, tr_mask.long(), train_mask.long())

        loss_tcl = torch.tensor(0., device=device)
        loss_reg_x = torch.tensor(0., device=device)
        loss_reg_y = torch.tensor(0., device=device)
        if not has_pos:
            return loss_tr, loss_tcl, loss_reg_x, loss_reg_y

        # Text-center-line loss: pixels inside trained text regions count
        # fully, all remaining pixels with half weight.
        loss_tcl = F.cross_entropy(tcl_pred[pos_sel],
                                   tcl_mask[pos_sel].long())
        # A mean-reduced cross entropy over an empty selection is NaN, so
        # the negative term is only added when such pixels exist.
        if neg_sel.any():
            loss_tcl_neg = F.cross_entropy(tcl_pred[neg_sel],
                                           tcl_mask[neg_sel].long())
            loss_tcl = loss_tcl + 0.5 * loss_tcl_neg

        # Regression loss, compared in contour-point space. Only the
        # selected rows are transformed; each row is independent, so this
        # matches transforming every pixel and masking afterwards.
        weight = (tr_mask[pos_sel].float() + tcl_mask[pos_sel].float()) / 2
        weight = weight.contiguous().view(-1, 1)
        ft_x, ft_y = self.fourier2poly(x_map[pos_sel], y_map[pos_sel])
        ft_x_pre, ft_y_pre = self.fourier2poly(x_pred[pos_sel],
                                               y_pred[pos_sel])

        loss_reg_x = torch.mean(
            weight * F.smooth_l1_loss(ft_x_pre, ft_x, reduction='none'))
        loss_reg_y = torch.mean(
            weight * F.smooth_l1_loss(ft_y_pre, ft_y, reduction='none'))

        return loss_tr, loss_tcl, loss_reg_x, loss_reg_y

    def ohem(self, predict, target, train_mask):
        """Cross-entropy loss with online hard example mining.

        All trained positives contribute; of the trained negatives only
        the hardest (largest loss) ones are kept.

        Args:
            predict (Tensor): Two-class logits, one row per pixel.
            target (Tensor): Integer 0/1 class label per pixel.
            train_mask (Tensor): Integer 0/1 flag per pixel; pixels with 0
                are ignored.

        Returns:
            Tensor: Summed loss of the kept pixels divided by
            ``n_pos + n_neg``.
        """
        device = train_mask.device
        pos = (target * train_mask).bool()
        neg = ((1 - target) * train_mask).bool()

        n_pos = pos.float().sum()
        # Per-pixel negative losses are needed whether or not positives
        # exist, so they are computed once up front.
        loss_neg = F.cross_entropy(
            predict[neg], target[neg], reduction='none')

        if n_pos.item() > 0:
            loss_pos = F.cross_entropy(
                predict[pos], target[pos], reduction='sum')
            n_neg = min(
                int(neg.float().sum().item()),
                int(self.ohem_ratio * n_pos.item()))
        else:
            loss_pos = torch.tensor(0., device=device)
            # Without positives a fixed budget of 100 negatives is used,
            # both for the selection and for the normaliser below.
            n_neg = 100

        if len(loss_neg) > n_neg:
            loss_neg, _ = torch.topk(loss_neg, n_neg)

        return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()

    def fourier2poly(self, real_maps, imag_maps):
        """Transform Fourier coefficient maps to polygon maps.

        Args:
            real_maps (Tensor): A map composed of the real parts of the
                Fourier coefficients, whose shape is (..., 2k+1).
            imag_maps (Tensor): A map composed of the imaginary parts of
                the Fourier coefficients, whose shape is (..., 2k+1).

        Returns:
            tuple[Tensor, Tensor]: ``(x_maps, y_maps)``, the x and y
            values of the polygon represented by n sample points
            (xn, yn), each of shape (..., n).
        """
        device = real_maps.device

        k_vect = torch.arange(
            -self.fourier_degree,
            self.fourier_degree + 1,
            dtype=torch.float,
            device=device).view(-1, 1)
        i_vect = torch.arange(
            0, self.num_sample, dtype=torch.float, device=device).view(1, -1)

        # Phase 2*pi*k*i/n for every (coefficient, sample point) pair.
        transform_matrix = 2 * np.pi / self.num_sample * torch.mm(
            k_vect, i_vect)
        # The phases are built in float32 and then cast, so inputs of other
        # floating dtypes can be multiplied without a dtype mismatch.
        cos_mat = torch.cos(transform_matrix).to(real_maps.dtype)
        sin_mat = torch.sin(transform_matrix).to(real_maps.dtype)

        # Real and imaginary parts of sum_k (re + j*im) * exp(j * phase).
        x_maps = torch.matmul(real_maps, cos_mat) - torch.matmul(
            imag_maps, sin_mat)
        y_maps = torch.matmul(real_maps, sin_mat) + torch.matmul(
            imag_maps, cos_mat)
        return x_maps, y_maps
Read the Docs v: v0.4.0
Versions
latest
stable
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.