Note
You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code, and documentation of MMOCR 1.0 for more details.
Source code for mmocr.models.textrecog.backbones.resnet
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule, build_plugin_layer
from mmcv.runner import BaseModule, Sequential
import mmocr.utils as utils
from mmocr.models.builder import BACKBONES
from mmocr.models.textrecog.layers import BasicBlock
@BACKBONES.register_module()
class ResNet(BaseModule):
    """Configurable ResNet backbone for text recognition.

    The network consists of a stack of 3x3 ``ConvModule`` stem layers
    followed by one residual stage per entry of ``arch_layers``. Optional
    plugin layers can be attached before and/or after every stage.

    Args:
        in_channels (int): Number of channels of input image tensor.
        stem_channels (int | list[int]): Channels of each stem layer. E.g.,
            [64, 128] stands for 64 and 128 channels in the first and second
            stem layers. A single int builds one stem layer.
        block_cfgs (dict): Config of the residual block. Its ``type`` key
            selects the block class; only ``'BasicBlock'`` is supported. The
            remaining keys are forwarded to the block constructor.
        arch_layers (list[int]): Number of blocks in each stage.
        arch_channels (list[int]): Output channels of each stage.
        strides (Sequence[int] | Sequence[tuple]): Strides of the first block
            of each stage.
        out_indices (None | Sequence[int]): Indices of output stages. If not
            specified (or empty), only the output of the last stage is
            returned.
        plugins (list[dict], optional): Configs of stage plugins. See
            :meth:`_make_stage_plugins` for the expected format.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 stem_channels,
                 block_cfgs,
                 arch_layers,
                 arch_channels,
                 strides,
                 out_indices=None,
                 plugins=None,
                 init_cfg=[
                     dict(type='Xavier', layer='Conv2d'),
                     dict(type='Constant', val=1, layer='BatchNorm2d'),
                 ]):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, int)
        assert isinstance(stem_channels, int) or utils.is_type_list(
            stem_channels, int)
        assert utils.is_type_list(arch_layers, int)
        assert utils.is_type_list(arch_channels, int)
        assert utils.is_type_list(strides, tuple) or utils.is_type_list(
            strides, int)
        assert len(arch_layers) == len(arch_channels) == len(strides)
        assert out_indices is None or isinstance(out_indices, (list, tuple))

        self.out_indices = out_indices
        # Builds ``self.stem_layers`` and initialises ``self.inplanes``.
        self._make_stem_layer(in_channels, stem_channels)
        self.num_stages = len(arch_layers)
        self.arch_channels = arch_channels
        # Names (not modules) of the registered stages, in forward order.
        self.res_layers = []

        self.use_plugins = plugins is not None
        if self.use_plugins:
            # One list of plugin module names per stage; filled stage by
            # stage in ``_make_stage_plugins``.
            self.plugin_ahead_names = []
            self.plugin_after_names = []

        for i, (num_blocks, channel, stride) in enumerate(
                zip(arch_layers, arch_channels, strides)):
            if self.use_plugins:
                self._make_stage_plugins(plugins, stage_idx=i)
            res_layer = self._make_layer(
                block_cfgs=block_cfgs,
                inplanes=self.inplanes,
                planes=channel,
                blocks=num_blocks,
                stride=stride,
            )
            # The next stage consumes this stage's output channels.
            self.inplanes = channel
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

    def _make_layer(self, block_cfgs, inplanes, planes, blocks, stride):
        """Build one residual stage.

        Args:
            block_cfgs (dict): Block config; ``type`` selects the block class
                and the remaining keys are passed to every block.
            inplanes (int): Input channels of the stage.
            planes (int): Output channels of the stage.
            blocks (int): Number of blocks in the stage.
            stride (int | tuple): Stride of the first block. An int is
                expanded to ``(stride, stride)``.

        Returns:
            Sequential: The blocks of the stage.

        Raises:
            ValueError: If ``block_cfgs['type']`` is not ``'BasicBlock'``.
        """
        # Work on a copy so the caller's config can be reused for the
        # remaining stages.
        block_cfgs_ = block_cfgs.copy()
        block_type = block_cfgs_.pop('type')
        if block_type != 'BasicBlock':
            raise ValueError(f'{block_type} not implement yet')
        block = BasicBlock

        if isinstance(stride, int):
            stride = (stride, stride)

        # A 1x1 projection is needed on the shortcut whenever the first block
        # changes the spatial size or the channel count.
        downsample = None
        if stride[0] != 1 or stride[1] != 1 or inplanes != planes:
            downsample = ConvModule(
                inplanes,
                planes,
                1,
                stride,
                norm_cfg=dict(type='BN'),
                act_cfg=None)

        layers = [
            block(
                inplanes,
                planes,
                stride=stride,
                downsample=downsample,
                **block_cfgs_)
        ]
        # The remaining blocks keep the channel count and use the block's
        # default stride.
        layers.extend(
            block(planes, planes, **block_cfgs_) for _ in range(1, blocks))
        return Sequential(*layers)

    def _make_stem_layer(self, in_channels, stem_channels):
        """Build the stem as a chain of 3x3 conv-BN-ReLU layers.

        Sets ``self.stem_layers`` and initialises ``self.inplanes`` to the
        channel count of the last stem layer.

        Args:
            in_channels (int): Channels of the input image tensor.
            stem_channels (int | list[int]): Output channels of each stem
                layer; a single int builds one layer.
        """
        if isinstance(stem_channels, int):
            stem_channels = [stem_channels]
        stem_layers = []
        for channels in stem_channels:
            stem_layers.append(
                ConvModule(
                    in_channels,
                    channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                    norm_cfg=dict(type='BN'),
                    act_cfg=dict(type='ReLU')))
            in_channels = channels
        self.stem_layers = Sequential(*stem_layers)
        self.inplanes = stem_channels[-1]

    def _make_stage_plugins(self, plugins, stage_idx):
        """Make plugins for ResNet ``stage_idx`` th stage.

        Currently we support inserting ``nn.Maxpooling``,
        ``mmcv.cnn.Convmodule`` into the backbone. Originally designed
        for ResNet31-like architectures.

        Examples:
            >>> plugins=[
            ...     dict(cfg=dict(type="Maxpooling", arg=(2,2)),
            ...          stages=(True, True, False, False),
            ...          position='before_stage'),
            ...     dict(cfg=dict(type="Maxpooling", arg=(2,1)),
            ...          stages=(False, False, True, False),
            ...          position='before_stage'),
            ...     dict(cfg=dict(
            ...              type='ConvModule',
            ...              kernel_size=3,
            ...              stride=1,
            ...              padding=1,
            ...              norm_cfg=dict(type='BN'),
            ...              act_cfg=dict(type='ReLU')),
            ...          stages=(True, True, True, True),
            ...          position='after_stage')]

        Suppose ``stage_idx=1``, the structure of stage would be:

        .. code-block:: none

            Maxpooling -> A set of Basicblocks -> ConvModule

        Each built plugin is registered as a submodule and its name is
        appended to ``self.plugin_ahead_names[stage_idx]`` or
        ``self.plugin_after_names[stage_idx]``; nothing is returned.

        Args:
            plugins (list[dict]): List of plugins cfg to build. Each entry
                holds ``cfg`` (the layer config), ``position``
                (``'before_stage'`` or ``'after_stage'``) and optionally
                ``stages`` (one flag per stage; omitted or ``None`` applies
                the plugin to every stage).
            stage_idx (int): Index of stage to build

        Raises:
            ValueError: If a plugin selected for this stage has an
                unsupported ``position``.
        """
        # Plugins on both sides are built with the stage's output channels.
        # NOTE(review): at runtime a 'before_stage' plugin receives the
        # previous stage's (or the stem's) output, so a channel-dependent
        # plugin placed there may see a different channel count -- confirm.
        in_channels = self.arch_channels[stage_idx]
        self.plugin_ahead_names.append([])
        self.plugin_after_names.append([])
        for plugin in plugins:
            # Copy so that popping keys does not alter the caller's config,
            # which is reused for the other stages.
            plugin = plugin.copy()
            stages = plugin.pop('stages', None)
            position = plugin.pop('position', None)
            assert stages is None or len(stages) == self.num_stages
            # ``stages=None`` means the plugin applies to all stages.
            if stages is not None and not stages[stage_idx]:
                continue
            if position == 'before_stage':
                postfix = f'_before_stage_{stage_idx+1}'
                names = self.plugin_ahead_names[stage_idx]
            elif position == 'after_stage':
                postfix = f'_after_stage_{stage_idx+1}'
                names = self.plugin_after_names[stage_idx]
            else:
                raise ValueError('uncorrect plugin position')
            name, layer = build_plugin_layer(
                plugin['cfg'],
                postfix,
                in_channels=in_channels,
                out_channels=in_channels)
            names.append(name)
            self.add_module(name, layer)

    def forward_plugin(self, x, plugin_name):
        """Apply the named plugin modules to ``x`` one after another.

        Args:
            x (Tensor): Input feature.
            plugin_name (list[str]): Attribute names of the plugin modules,
                in application order.

        Returns:
            Tensor: Output of the last plugin, or ``x`` itself if
            ``plugin_name`` is empty.
        """
        out = x
        for name in plugin_name:
            out = getattr(self, name)(out)
        return out

    def forward(self, x):
        """
        Args:
            x (Tensor): Image tensor of shape :math:`(N, C, H, W)`.

        Returns:
            Tensor or tuple[Tensor]: Feature tensor of the last stage, or a
            tuple of the stage outputs selected by ``out_indices`` if it is
            specified.
        """
        x = self.stem_layers(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            if self.use_plugins:
                x = self.forward_plugin(x, self.plugin_ahead_names[i])
            x = res_layer(x)
            if self.use_plugins:
                x = self.forward_plugin(x, self.plugin_after_names[i])
            if self.out_indices and i in self.out_indices:
                outs.append(x)
        return tuple(outs) if self.out_indices else x