Shortcuts

Source code for mmocr.models.textdet.heads.db_head

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine.model import Sequential
from torch import Tensor

from mmocr.models.textdet.heads import BaseTextDetHead
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from mmocr.utils.typing_utils import DetSampleList


@MODELS.register_module()
class DBHead(BaseTextDetHead):
    """The class for DBNet head.

    This was partially adapted from https://github.com/MhLiao/DB

    The head owns two structurally parallel branches that both upsample the
    input feature map by a factor of 4 (two stride-2 transposed
    convolutions): ``binarize`` produces probability logits and ``threshold``
    produces a sigmoid-activated threshold map.

    Args:
        in_channels (int): The number of input channels.
        with_bias (bool): Whether add bias in Conv2d layer. Only the first
            3x3 convolution of the ``binarize`` branch receives this flag;
            the ``threshold`` branch is built with ``_init_thr``'s default
            ``bias=False``. Defaults to False.
        module_loss (dict): Config of loss for dbnet. Defaults to
            ``dict(type='DBModuleLoss')``
        postprocessor (dict): Config of postprocessor for dbnet. Defaults to
            ``dict(type='DBPostprocessor', text_repr_type='quad')``.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to Kaiming init for conv layers and constant init
            (``val=1.``, ``bias=1e-4``) for batch-norm layers.
    """

    def __init__(
        self,
        in_channels: int,
        with_bias: bool = False,
        module_loss: Dict = dict(type='DBModuleLoss'),
        postprocessor: Dict = dict(
            type='DBPostprocessor', text_repr_type='quad'),
        init_cfg: Optional[Union[Dict, List[Dict]]] = [
            dict(type='Kaiming', layer='Conv'),
            dict(type='Constant', layer='BatchNorm', val=1., bias=1e-4)
        ]
    ) -> None:
        super().__init__(
            module_loss=module_loss,
            postprocessor=postprocessor,
            init_cfg=init_cfg)

        assert isinstance(in_channels, int)
        assert isinstance(with_bias, bool)

        self.in_channels = in_channels
        mid_channels = in_channels // 4

        # Probability branch: 3x3 conv -> BN -> ReLU -> 2x up -> BN -> ReLU
        # -> 2x up to a single-channel logit map (no final activation; the
        # sigmoid is applied separately in ``forward``).
        self.binarize = Sequential(
            nn.Conv2d(
                in_channels, mid_channels, 3, bias=with_bias, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(mid_channels, mid_channels, 2, 2),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(mid_channels, 1, 2, 2))
        # NOTE(review): ``with_bias`` is intentionally not forwarded here to
        # keep the existing parameter layout; confirm whether the threshold
        # branch should honour it as well.
        self.threshold = self._init_thr(in_channels)
        self.sigmoid = nn.Sigmoid()

    def _diff_binarize(self, prob_map: Tensor, thr_map: Tensor,
                       k: int) -> Tensor:
        """Differentiable binarization.

        Computes ``1 / (1 + exp(-k * (prob_map - thr_map)))`` element-wise.

        Args:
            prob_map (Tensor): Probability map.
            thr_map (Tensor): Threshold map.
            k (int): Amplification factor.

        Returns:
            Tensor: Approximate binary map with the broadcast shape of the
            two inputs.
        """
        return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))

    def _init_thr(self,
                  inner_channels: int,
                  bias: bool = False) -> Sequential:
        """Initialize threshold branch.

        Args:
            inner_channels (int): Number of channels of the input feature
                map; the intermediate layers use ``inner_channels // 4``.
            bias (bool): Whether the first 3x3 convolution has a bias term.
                Defaults to False.

        Returns:
            Sequential: Conv -> BN -> ReLU -> 2x up -> BN -> ReLU -> 2x up
            -> Sigmoid, producing a single-channel map.
        """
        mid_channels = inner_channels // 4
        return Sequential(
            nn.Conv2d(
                inner_channels, mid_channels, 3, padding=1, bias=bias),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(mid_channels, mid_channels, 2, 2),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(mid_channels, 1, 2, 2),
            nn.Sigmoid())

    def forward(
        self,
        img: Tensor,
        data_samples: Optional[List[TextDetDataSample]] = None,
        mode: str = 'predict'
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        """
        Args:
            img (Tensor): Shape :math:`(N, C, H, W)`.
            data_samples (list[TextDetDataSample], optional): A list of data
                samples. Not used by this method. Defaults to None.
            mode (str): Forward mode. It affects the return values. Options
                are "loss", "predict" and "both". Any value other than
                "predict" or "loss" is handled like "both". Defaults to
                "predict".

                - ``loss``: Run the full network and return the prob
                  logits, threshold map and binary map.
                - ``predict``: Run the binarization part and return the prob
                  map only.
                - ``both``: Run the full network and return prob logits,
                  threshold map, binary map and prob map.

        Returns:
            Tensor or tuple(Tensor): Its type depends on ``mode``, read its
            docstring for details. Each has the shape of
            :math:`(N, 4H, 4W)`.
        """
        # Drop the singleton channel axis: (N, 1, 4H, 4W) -> (N, 4H, 4W).
        prob_logits = self.binarize(img).squeeze(1)
        prob_map = self.sigmoid(prob_logits)
        if mode == 'predict':
            # The threshold branch is skipped entirely at inference time.
            return prob_map

        thr_map = self.threshold(img).squeeze(1)
        # Both inputs are already channel-less, so the result needs no
        # further squeezing.
        binary_map = self._diff_binarize(prob_map, thr_map, k=50)
        if mode == 'loss':
            return prob_logits, thr_map, binary_map
        return prob_logits, thr_map, binary_map, prob_map

    def loss(self, x: Tensor, batch_data_samples: DetSampleList) -> Dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            x (Tensor): Feature map from the upstream network; it is fed
                directly to :meth:`forward` as ``img``.
            batch_data_samples (DetSampleList): The data samples, passed
                through unchanged to ``module_loss``.

        Returns:
            dict: A dictionary of loss components.
        """
        outs = self(x, batch_data_samples, mode='loss')
        return self.module_loss(outs, batch_data_samples)

    def loss_and_predict(self, x: Tensor, batch_data_samples: DetSampleList
                         ) -> Tuple[dict, DetSampleList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples.

        Args:
            x (Tensor): Feature map from the upstream network; it is fed
                directly to :meth:`forward` as ``img``.
            batch_data_samples (DetSampleList): The data samples, passed
                through unchanged to ``module_loss`` and ``postprocessor``.

        Returns:
            tuple: the return value is a tuple contains:

            - losses: (dict[str, Tensor]): A dictionary of loss components.
            - predictions: Whatever ``postprocessor`` returns for the
              probability map and the data samples.
        """
        prob_logits, thr_map, binary_map, prob_map = self(
            x, batch_data_samples, mode='both')
        # The loss consumes the first three outputs; the postprocessor only
        # needs the probability map.
        losses = self.module_loss((prob_logits, thr_map, binary_map),
                                  batch_data_samples)
        predictions = self.postprocessor(prob_map, batch_data_samples)
        return losses, predictions

    def predict(self, x: Tensor,
                batch_data_samples: DetSampleList) -> DetSampleList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.

        Args:
            x (Tensor): Feature map from the upstream network; it is fed
                directly to :meth:`forward` as ``img``.
            batch_data_samples (DetSampleList): The data samples, passed
                through unchanged to ``postprocessor``.

        Returns:
            SampleList: Detection results of each image after the post
            process.
        """
        prob_map = self(x, batch_data_samples, mode='predict')
        return self.postprocessor(prob_map, batch_data_samples)
Read the Docs v: dev-1.x
Versions
latest
stable
v1.0.1
v1.0.0
0.x
v0.6.3
v0.6.2
v0.6.1
v0.6.0
v0.5.0
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.