
Source code for mmocr.models.textdet.dense_heads.db_head

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
from mmcv.runner import BaseModule, Sequential

from mmocr.models.builder import HEADS
from .head_mixin import HeadMixin


@HEADS.register_module()
class DBHead(HeadMixin, BaseModule):
    """The class for the DBNet head.

    This was partially adapted from https://github.com/MhLiao/DB

    Args:
        in_channels (int): The number of input channels of the DB head.
        with_bias (bool): Whether to add bias in the Conv2d layers.
        downsample_ratio (float): The downsample ratio of ground truths.
        loss (dict): Config of loss for DBNet.
        postprocessor (dict): Config of postprocessor for DBNet.
    """

    def __init__(self,
                 in_channels,
                 with_bias=False,
                 downsample_ratio=1.0,
                 loss=dict(type='DBLoss'),
                 postprocessor=dict(
                     type='DBPostprocessor', text_repr_type='quad'),
                 init_cfg=[
                     dict(type='Kaiming', layer='Conv'),
                     dict(
                         type='Constant', layer='BatchNorm', val=1., bias=1e-4)
                 ],
                 train_cfg=None,
                 test_cfg=None,
                 **kwargs):
        old_keys = ['text_repr_type', 'decoding_type']
        for key in old_keys:
            if kwargs.get(key, None):
                postprocessor[key] = kwargs.get(key)
                warnings.warn(
                    f'{key} is deprecated, please specify '
                    'it in postprocessor config dict. See '
                    'https://github.com/open-mmlab/mmocr/pull/640'
                    ' for details.', UserWarning)

        BaseModule.__init__(self, init_cfg=init_cfg)
        HeadMixin.__init__(self, loss, postprocessor)

        assert isinstance(in_channels, int)

        self.in_channels = in_channels
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.downsample_ratio = downsample_ratio

        # Probability-map branch: the two stride-2 transposed convolutions
        # upsample the prediction by a factor of 4 overall.
        self.binarize = Sequential(
            nn.Conv2d(
                in_channels, in_channels // 4, 3, bias=with_bias, padding=1),
            nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
            nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid())

        self.threshold = self._init_thr(in_channels)

    def diff_binarize(self, prob_map, thr_map, k):
        """Differentiable binarization: 1 / (1 + exp(-k * (P - T)))."""
        return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))

    def forward(self, inputs):
        """
        Args:
            inputs (Tensor): Shape (batch_size, hidden_size, h, w).

        Returns:
            Tensor: A tensor of shape (batch_size, 3, 4 * h, 4 * w), with the
            probability, threshold and approximate binary maps concatenated
            along the channel dimension.
        """
        prob_map = self.binarize(inputs)
        thr_map = self.threshold(inputs)
        binary_map = self.diff_binarize(prob_map, thr_map, k=50)
        outputs = torch.cat((prob_map, thr_map, binary_map), dim=1)
        return outputs

    def _init_thr(self, inner_channels, bias=False):
        """Build the threshold-map branch (same structure as the probability
        branch)."""
        in_channels = inner_channels
        seq = Sequential(
            nn.Conv2d(
                in_channels, inner_channels // 4, 3, padding=1, bias=bias),
            nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2,
                               2), nn.BatchNorm2d(inner_channels // 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2), nn.Sigmoid())
        return seq
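
For reference, a minimal usage sketch follows. It is not part of the original source file; it assumes the default DBLoss and DBPostprocessor configs can be built from MMOCR's registry, and the 256-channel, 40 x 40 feature map is an arbitrary illustrative size. It also checks that diff_binarize is numerically the same as a sigmoid with slope k.

import torch

from mmocr.models.textdet.dense_heads import DBHead

head = DBHead(in_channels=256)
head.eval()

# A fake neck output: batch of 1, 256 channels, 40 x 40 spatial size.
feats = torch.rand(1, 256, 40, 40)
with torch.no_grad():
    out = head(feats)

# Channel 0 is the probability map, channel 1 the threshold map and
# channel 2 the approximate binary map; each branch upsamples by 4x.
print(out.shape)  # torch.Size([1, 3, 160, 160])

# diff_binarize(P, T, k) is equivalent to sigmoid(k * (P - T)).
p, t = out[:, 0:1], out[:, 1:2]
assert torch.allclose(head.diff_binarize(p, t, k=50),
                      torch.sigmoid(50 * (p - t)))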