Note
You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to enjoy the rich new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.
Source code for mmocr.models.textrecog.layers.conv_layer
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import build_plugin_layer
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with a padding of 1.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Stride of the convolution. Defaults to 1.

    Returns:
        nn.Conv2d: The constructed convolution layer.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
def conv1x1(in_planes, out_planes):
    """1x1 convolution with stride 1, no padding and no bias.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.

    Returns:
        nn.Conv2d: The constructed convolution layer.
    """
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
class BasicBlock(nn.Module):
    """Two-convolution residual block with optional plugin layers.

    The main branch is ``conv1 -> bn1 -> relu -> conv2 -> bn2``; the shortcut
    branch is added to it and the sum goes through ``relu``. Plugin layers
    can be inserted at four positions: ``'before_conv1'``, ``'after_conv1'``,
    ``'after_conv2'`` and ``'after_shortcut'``.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of channels produced by ``conv1``. The block
            outputs ``planes * expansion`` channels.
        stride (int): Stride of the strided 3x3 convolution. Defaults to 1.
        downsample (nn.Module, optional): Module applied to the block input
            to produce the shortcut branch. If None, the input itself is
            used as the shortcut. Defaults to None.
        use_conv1x1 (bool): If True, ``conv1`` is a 1x1 convolution and the
            stride is applied in ``conv2``; otherwise ``conv1`` is a 3x3
            convolution that carries the stride. Defaults to False.
        plugins (dict | list[dict], optional): Plugin settings. Each dict
            must provide a ``'cfg'`` entry (passed to ``build_plugin_layer``)
            and a ``'position'`` entry. Defaults to None.
    """

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 use_conv1x1=False,
                 plugins=None):
        super(BasicBlock, self).__init__()
        if use_conv1x1:
            self.conv1 = conv1x1(inplanes, planes)
            self.conv2 = conv3x3(planes, planes * self.expansion, stride)
        else:
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.conv2 = conv3x3(planes, planes * self.expansion)

        self.with_plugins = False
        if plugins:
            # A single plugin may be given as a bare dict.
            if isinstance(plugins, dict):
                plugins = [plugins]
            self.with_plugins = True
            # Group the plugin cfgs by the position they are attached to.
            self.before_conv1_plugin = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'before_conv1'
            ]
            self.after_conv1_plugin = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv1'
            ]
            self.after_conv2_plugin = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv2'
            ]
            self.after_shortcut_plugin = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_shortcut'
            ]

        self.planes = planes
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.bn2 = nn.BatchNorm2d(planes * self.expansion)
        self.downsample = downsample
        self.stride = stride

        if self.with_plugins:
            # conv2/bn2 (and therefore the shortcut sum) produce
            # ``planes * expansion`` channels, so plugins placed after them
            # must be built for that width. With ``expansion == 1`` this is
            # the same as ``planes``.
            out_planes = planes * self.expansion
            self.before_conv1_plugin_names = self.make_block_plugins(
                inplanes, self.before_conv1_plugin)
            self.after_conv1_plugin_names = self.make_block_plugins(
                planes, self.after_conv1_plugin)
            self.after_conv2_plugin_names = self.make_block_plugins(
                out_planes, self.after_conv2_plugin)
            self.after_shortcut_plugin_names = self.make_block_plugins(
                out_planes, self.after_shortcut_plugin)

    def make_block_plugins(self, in_channels, plugins):
        """Build plugin layers and register them as sub-modules.

        Args:
            in_channels (int): Channel count passed to each plugin as both
                ``in_channels`` and ``out_channels``.
            plugins (list[dict]): List of plugin cfgs to build. An optional
                ``'postfix'`` key is removed from (a copy of) each cfg and
                forwarded to ``build_plugin_layer``.

        Returns:
            list[str]: Names of the registered plugin modules, in the order
            they were built.
        """
        assert isinstance(plugins, list)
        plugin_names = []
        for plugin in plugins:
            # Work on a copy so the caller's cfg is not mutated by ``pop``.
            plugin = plugin.copy()
            postfix = plugin.pop('postfix', '')
            name, layer = build_plugin_layer(
                plugin,
                in_channels=in_channels,
                out_channels=in_channels,
                postfix=postfix)
            assert not hasattr(self, name), f'duplicate plugin {name}'
            self.add_module(name, layer)
            plugin_names.append(name)
        return plugin_names

    def forward_plugin(self, x, plugin_names):
        """Apply the named plugins to ``x`` one after another.

        Args:
            x: Input passed to the first plugin.
            plugin_names (list[str]): Attribute names of the plugins to
                apply, in order.

        Returns:
            The output of the last plugin, or ``x`` unchanged when
            ``plugin_names`` is empty.
        """
        out = x
        for name in plugin_names:
            # Feed the running result, not the original input, so that
            # several plugins at the same position are actually chained.
            out = getattr(self, name)(out)
        return out

    def forward(self, x):
        """Run the residual block on ``x`` and return the result."""
        if self.with_plugins:
            x = self.forward_plugin(x, self.before_conv1_plugin_names)
        # The shortcut is taken after the 'before_conv1' plugins.
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        if self.with_plugins:
            out = self.forward_plugin(out, self.after_conv1_plugin_names)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.with_plugins:
            out = self.forward_plugin(out, self.after_conv2_plugin_names)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        if self.with_plugins:
            out = self.forward_plugin(out, self.after_shortcut_plugin_names)

        return out
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv, 3x3 conv, then 1x1 conv.

    Each convolution is followed by batch normalization. The block outputs
    ``planes * expansion`` channels.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of channels of the two inner convolutions.
        stride (int): Stride of the 3x3 convolution and of the shortcut
            projection. Defaults to 1.
        downsample (bool): If True, the shortcut is a strided 1x1
            convolution followed by batch normalization; otherwise it is an
            empty ``nn.Sequential`` that returns its input. Defaults to
            False.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=False):
        super().__init__()
        out_planes = planes * self.expansion

        # Main branch. Layers are created in the same order as before so
        # that sub-module registration order stays the same.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

        # Shortcut branch: projection layers when requested, otherwise an
        # empty container.
        shortcut_layers = []
        if downsample:
            shortcut_layers = [
                nn.Conv2d(
                    inplanes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        self.downsample = nn.Sequential(*shortcut_layers)

    def forward(self, x):
        """Run the bottleneck block on ``x`` and return the result."""
        shortcut = self.downsample(x)

        feat = self.relu(self.bn1(self.conv1(x)))
        feat = self.relu(self.bn2(self.conv2(feat)))
        feat = self.bn3(self.conv3(feat))

        feat += shortcut
        return self.relu(feat)