Shortcuts

Source code for mmocr.models.textdet.losses.textsnake_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmdet.core import BitmapMasks
from torch import nn

from mmocr.models.builder import LOSSES
from mmocr.utils import check_argument


@LOSSES.register_module()
class TextSnakeLoss(nn.Module):
    """The class for implementing TextSnake loss.

    This is partially adapted from
    https://github.com/princewang1994/TextSnake.pytorch.

    TextSnake: `A Flexible Representation for Detecting Text of Arbitrary
    Shapes <https://arxiv.org/abs/1807.01544>`_.

    Args:
        ohem_ratio (float): The negative/positive ratio in ohem.
    """

    def __init__(self, ohem_ratio=3.0):
        super().__init__()
        self.ohem_ratio = ohem_ratio

    def balanced_bce_loss(self, pred, gt, mask):
        """Compute a binary cross-entropy loss balanced with OHEM.

        Every positive element (``gt * mask``) contributes to the loss. Of
        the negative elements (``(1 - gt) * mask``) only those with the
        largest loss are kept: at most ``ohem_ratio`` times the number of
        positives. When there is no positive element at all, the (up to)
        100 largest entries of the masked negative loss map are used.

        Args:
            pred (Tensor): Predicted probabilities in ``[0, 1]``, as required
                by ``F.binary_cross_entropy``.
            gt (Tensor): Binary target with the same shape as ``pred``.
            mask (Tensor): Validity mask with the same shape as ``pred``;
                elements where it is 0 are excluded from the loss.

        Returns:
            Tensor: The scalar balanced loss.
        """
        assert pred.shape == gt.shape == mask.shape

        positive = gt * mask
        negative = (1 - gt) * mask
        positive_count = int(positive.float().sum())
        gt = gt.float()

        # The element-wise loss map is needed whether or not positives exist.
        loss = F.binary_cross_entropy(pred, gt, reduction='none')
        negative_loss = (loss * negative.float()).reshape(-1)

        if positive_count > 0:
            positive_loss = torch.sum(loss * positive.float())
            negative_count = min(
                int(negative.float().sum()),
                int(positive_count * self.ohem_ratio))
        else:
            positive_loss = torch.tensor(0.0, device=pred.device)
            # Fixed fallback budget of 100 negatives. Clamp it to the number
            # of available elements: torch.topk raises when k exceeds the
            # size of the dimension (e.g. on very small feature maps).
            negative_count = min(100, negative_loss.numel())

        negative_loss, _ = torch.topk(negative_loss, negative_count)

        # The small epsilon guards against a zero denominator.
        balance_loss = (positive_loss + torch.sum(negative_loss)) / (
            float(positive_count + negative_count) + 1e-5)

        return balance_loss

    def bitmasks2tensor(self, bitmasks, target_sz):
        """Convert Bitmasks to tensor.

        Each mask is padded with zeros on its right and bottom sides up to
        ``target_sz`` and the masks of the same level are stacked over the
        batch.

        Args:
            bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
                for one img.
            target_sz (tuple(int, int)): The target tensor of size
                :math:`(H, W)`.

        Returns:
            list[Tensor]: The list of kernel tensors. Each element stands for
            one kernel level.
        """
        assert check_argument.is_type_list(bitmasks, BitmapMasks)
        assert isinstance(target_sz, tuple)

        # The number of levels is taken from the first image.
        num_masks = len(bitmasks[0])

        results = []
        for level_inx in range(num_masks):
            kernel = []
            for img_masks in bitmasks:
                mask = torch.from_numpy(img_masks.masks[level_inx])
                # hxw
                mask_sz = mask.shape
                # left, right, top, bottom
                pad = [
                    0, target_sz[1] - mask_sz[1], 0,
                    target_sz[0] - mask_sz[0]
                ]
                kernel.append(F.pad(mask, pad, mode='constant', value=0))
            results.append(torch.stack(kernel))

        return results

    def forward(self, pred_maps, downsample_ratio, gt_text_mask,
                gt_center_region_mask, gt_mask, gt_radius_map, gt_sin_map,
                gt_cos_map):
        """Compute the TextSnake losses.

        Args:
            pred_maps (Tensor): The prediction map of shape
                :math:`(N, 5, H, W)`, where each dimension is the map of
                "text_region", "center_region", "sin_map", "cos_map", and
                "radius_map" respectively.
            downsample_ratio (float): Downsample ratio.
            gt_text_mask (list[BitmapMasks]): Gold text masks.
            gt_center_region_mask (list[BitmapMasks]): Gold center region
                masks.
            gt_mask (list[BitmapMasks]): Gold general masks.
            gt_radius_map (list[BitmapMasks]): Gold radius maps.
            gt_sin_map (list[BitmapMasks]): Gold sin maps.
            gt_cos_map (list[BitmapMasks]): Gold cos maps.

        Returns:
            dict: A loss dict with ``loss_text``, ``loss_center``,
            ``loss_radius``, ``loss_sin`` and ``loss_cos``.
        """
        assert isinstance(downsample_ratio, float)
        assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
        assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
        assert check_argument.is_type_list(gt_mask, BitmapMasks)
        assert check_argument.is_type_list(gt_radius_map, BitmapMasks)
        assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
        assert check_argument.is_type_list(gt_cos_map, BitmapMasks)

        pred_text_region = pred_maps[:, 0, :, :]
        pred_center_region = pred_maps[:, 1, :, :]
        pred_sin_map = pred_maps[:, 2, :, :]
        pred_cos_map = pred_maps[:, 3, :, :]
        pred_radius_map = pred_maps[:, 4, :, :]
        feature_sz = pred_maps.size()
        device = pred_maps.device

        # bitmask 2 tensor
        mapping = {
            'gt_text_mask': gt_text_mask,
            'gt_center_region_mask': gt_center_region_mask,
            'gt_mask': gt_mask,
            'gt_radius_map': gt_radius_map,
            'gt_sin_map': gt_sin_map,
            'gt_cos_map': gt_cos_map
        }
        gt = {}
        for key, value in mapping.items():
            gt[key] = value
            if abs(downsample_ratio - 1.0) < 1e-2:
                # Ratio is (almost) 1: no rescaling needed, only padding.
                gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
            else:
                gt[key] = [item.rescale(downsample_ratio) for item in gt[key]]
                gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
                if key == 'gt_radius_map':
                    # The radius values are scaled together with the map,
                    # presumably because they are expressed in pixels.
                    gt[key] = [item * downsample_ratio for item in gt[key]]
            gt[key] = [item.to(device) for item in gt[key]]

        # Normalise the predicted (sin, cos) pair to unit length; the epsilon
        # avoids a division by zero.
        scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
        pred_sin_map = pred_sin_map * scale
        pred_cos_map = pred_cos_map * scale

        loss_text = self.balanced_bce_loss(
            torch.sigmoid(pred_text_region), gt['gt_text_mask'][0],
            gt['gt_mask'][0])

        # The center-region loss is averaged over valid text pixels only.
        text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
        loss_center_map = F.binary_cross_entropy(
            torch.sigmoid(pred_center_region),
            gt['gt_center_region_mask'][0].float(),
            reduction='none')
        if int(text_mask.sum()) > 0:
            loss_center = torch.sum(
                loss_center_map * text_mask) / torch.sum(text_mask)
        else:
            loss_center = torch.tensor(0.0, device=device)

        # The geometry losses are averaged over valid center-region pixels.
        center_mask = (gt['gt_center_region_mask'][0] *
                       gt['gt_mask'][0]).float()
        if int(center_mask.sum()) > 0:
            center_count = torch.sum(center_mask)
            map_sz = pred_radius_map.size()
            ones = torch.ones(map_sz, dtype=torch.float, device=device)
            # The radius is regressed as a ratio to the ground truth, with
            # target 1; the 1e-2 offset keeps the division finite.
            loss_radius = torch.sum(
                F.smooth_l1_loss(
                    pred_radius_map / (gt['gt_radius_map'][0] + 1e-2),
                    ones,
                    reduction='none') * center_mask) / center_count
            loss_sin = torch.sum(
                F.smooth_l1_loss(
                    pred_sin_map, gt['gt_sin_map'][0], reduction='none') *
                center_mask) / center_count
            loss_cos = torch.sum(
                F.smooth_l1_loss(
                    pred_cos_map, gt['gt_cos_map'][0], reduction='none') *
                center_mask) / center_count
        else:
            loss_radius = torch.tensor(0.0, device=device)
            loss_sin = torch.tensor(0.0, device=device)
            loss_cos = torch.tensor(0.0, device=device)

        results = dict(
            loss_text=loss_text,
            loss_center=loss_center,
            loss_radius=loss_radius,
            loss_sin=loss_sin,
            loss_cos=loss_cos)

        return results
Read the Docs v: v0.4.0
Versions
latest
stable
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.