

You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend that you upgrade to MMOCR 1.0 to enjoy the rich new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.

Source code for mmocr.models.textdet.dense_heads.fce_head

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch.nn as nn
from mmcv.runner import BaseModule
from mmdet.core import multi_apply

from mmocr.models.builder import HEADS
from ..postprocess.utils import poly_nms
from .head_mixin import HeadMixin

@HEADS.register_module()
class FCEHead(HeadMixin, BaseModule):
    """The class for implementing FCENet head.

    FCENet(CVPR2021): `Fourier Contour Embedding for Arbitrary-shaped Text
    Detection <https://arxiv.org/abs/2104.10442>`_

    Args:
        in_channels (int): The number of input channels.
        scales (list[int]) : The scale of each layer.
        fourier_degree (int) : The maximum Fourier transform degree k.
        nms_thr (float) : The threshold of nms.
        loss (dict): Config of loss for FCENet. The dict is copied before
            ``fourier_degree`` is written into it, so the caller's dict (and
            the shared default) is left untouched.
        postprocessor (dict): Config of postprocessor for FCENet. Copied
            for the same reason before ``fourier_degree``, ``nms_thr`` and
            any deprecated keyword overrides are written into it.
        train_cfg: Stored unchanged as ``self.train_cfg``.
        test_cfg: Stored unchanged as ``self.test_cfg``.
        init_cfg (dict): Initialization config forwarded to
            ``BaseModule.__init__``.
        **kwargs: Deprecated options (``text_repr_type``, ``decoding_type``,
            ``num_reconstr_points``, ``alpha``, ``beta``, ``score_thr`` and
            ``num_sample``). Each one that is given with a non-``None``
            value is forwarded into the matching config dict and triggers a
            ``UserWarning``.
    """

    def __init__(self,
                 in_channels,
                 scales,
                 fourier_degree=5,
                 nms_thr=0.1,
                 loss=dict(type='FCELoss', num_sample=50),
                 postprocessor=dict(
                     type='FCEPostprocessor',
                     text_repr_type='poly',
                     num_reconstr_points=50,
                     alpha=1.0,
                     beta=2.0,
                     score_thr=0.3),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Normal',
                     mean=0,
                     std=0.01,
                     override=[
                         dict(name='out_conv_cls'),
                         dict(name='out_conv_reg')
                     ]),
                 **kwargs):
        # The default config dicts are created once at definition time and
        # shared by every call. Work on shallow copies so the writes below
        # never leak into other instances or into the caller's config.
        loss = loss.copy()
        postprocessor = postprocessor.copy()

        old_keys = [
            'text_repr_type', 'decoding_type', 'num_reconstr_points', 'alpha',
            'beta', 'score_thr'
        ]
        for key in old_keys:
            value = kwargs.get(key)
            # Compare against None rather than relying on truthiness so that
            # explicit falsy values (e.g. ``alpha=0``) are still honoured.
            if value is not None:
                postprocessor[key] = value
                # NOTE(review): the reference between 'See ' and
                # ' for details.' is an empty string in this message.
                warnings.warn(
                    f'{key} is deprecated, please specify '
                    'it in postprocessor config dict. See '
                    ''
                    ' for details.', UserWarning)

        num_sample = kwargs.get('num_sample')
        if num_sample is not None:
            loss['num_sample'] = num_sample
            warnings.warn(
                'num_sample is deprecated, please specify '
                'it in loss config dict. See '
                ''
                ' for details.', UserWarning)

        BaseModule.__init__(self, init_cfg=init_cfg)

        # The loss and the postprocessor are configured with the same
        # Fourier degree (and NMS threshold) as the head itself.
        loss['fourier_degree'] = fourier_degree
        postprocessor['fourier_degree'] = fourier_degree
        postprocessor['nms_thr'] = nms_thr
        HeadMixin.__init__(self, loss, postprocessor)

        assert isinstance(in_channels, int)

        self.downsample_ratio = 1.0
        self.in_channels = in_channels
        self.scales = scales
        self.fourier_degree = fourier_degree
        self.nms_thr = nms_thr
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.out_channels_cls = 4
        # (2k + 1) values, doubled -> 4k + 2 regression channels, which is
        # what ``_get_boundary_single`` checks for.
        self.out_channels_reg = (2 * self.fourier_degree + 1) * 2

        self.out_conv_cls = nn.Conv2d(
            self.in_channels,
            self.out_channels_cls,
            kernel_size=3,
            stride=1,
            padding=1)
        self.out_conv_reg = nn.Conv2d(
            self.in_channels,
            self.out_channels_reg,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self, feats):
        """
        Args:
            feats (list[Tensor]): Each tensor has the shape of
                :math:`(N, C_i, H_i, W_i)`.

        Returns:
            list[[Tensor, Tensor]]: Each pair of tensors corresponds to the
            classification result and regression result computed from the
            input tensor with the same index. They have the shapes of
            :math:`(N, C_{cls,i}, H_i, W_i)` and
            :math:`(N, C_{out,i}, H_i, W_i)`.
        """
        cls_res, reg_res = multi_apply(self.forward_single, feats)
        # Regroup the two per-level sequences into one [cls, reg] pair per
        # feature level.
        return [[cls, reg] for cls, reg in zip(cls_res, reg_res)]

    def forward_single(self, x):
        """Run both output convolutions on a single feature map.

        Args:
            x (Tensor): Input passed to ``out_conv_cls`` and
                ``out_conv_reg``.

        Returns:
            tuple: ``(cls_predict, reg_predict)``, the outputs of the
            classification and regression convolutions.
        """
        cls_predict = self.out_conv_cls(x)
        reg_predict = self.out_conv_reg(x)
        return cls_predict, reg_predict

    def get_boundary(self, score_maps, img_metas, rescale):
        """Decode the predictions of all levels into text boundaries.

        Args:
            score_maps (list): One entry per scale in ``self.scales``; each
                entry is handed to :meth:`_get_boundary_single`.
            img_metas (list[dict]): Only ``img_metas[0]['scale_factor']`` is
                read, and only when ``rescale`` is truthy.
            rescale (bool): If truthy, the boundaries are passed through
                ``self.resize_boundary`` with the reciprocal of the scale
                factor.

        Returns:
            dict: ``dict(boundary_result=boundaries)``.
        """
        assert len(score_maps) == len(self.scales)

        boundaries = []
        for score_map, scale in zip(score_maps, self.scales):
            # Extend in place instead of rebuilding the list on every level.
            boundaries.extend(self._get_boundary_single(score_map, scale))

        # nms
        boundaries = poly_nms(boundaries, self.nms_thr)

        if rescale:
            boundaries = self.resize_boundary(
                boundaries, 1.0 / img_metas[0]['scale_factor'])

        results = dict(boundary_result=boundaries)
        return results

    def _get_boundary_single(self, score_map, scale):
        """Validate one level's prediction pair and run the postprocessor.

        Args:
            score_map (sequence): Must hold exactly two items; the second
                must have ``4 * fourier_degree + 2`` channels on dim 1.
            scale: The scale associated with this level, forwarded to the
                postprocessor.

        Returns:
            The result of ``self.postprocessor(score_map, scale)``.
        """
        assert len(score_map) == 2
        assert score_map[1].shape[1] == 4 * self.fourier_degree + 2

        return self.postprocessor(score_map, scale)
Read the Docs v: v0.6.3
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.