Shortcuts

mmocr.models.textrecog.layers.conv_layer 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import build_plugin_layer


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Convolution stride. Default: 1.

    Returns:
        nn.Conv2d: The convolution layer. With ``stride=1`` the padding of 1
        keeps the spatial size unchanged.
    """
    # Positional order of nn.Conv2d: in, out, kernel_size, stride, padding.
    layer = nn.Conv2d(in_planes, out_planes, 3, stride, 1, bias=False)
    return layer


def conv1x1(in_planes, out_planes):
    """Build a bias-free 1x1 ``nn.Conv2d`` (stride 1, no padding).

    A 1x1 kernel needs no padding to preserve spatial size, so ``padding`` is
    0. The previous docstring ("with padding") did not match the code.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.

    Returns:
        nn.Conv2d: The convolution layer.
    """
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)


class BasicBlock(nn.Module):
    """Two-convolution residual block with optional plugin layers.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Output channels of ``conv1``; ``conv2`` outputs
            ``planes * expansion`` channels.
        stride (int): Stride of ``conv1`` by default, or of ``conv2`` when
            ``use_conv1x1`` is True. Default: 1.
        downsample (nn.Module, optional): Applied to the block input (after
            any ``before_conv1`` plugins) to produce the residual. If None,
            that input is used as the residual directly. Default: None.
        use_conv1x1 (bool): Use a 1x1 ``conv1`` and move the stride to the
            3x3 ``conv2``. Default: False.
        plugins (dict | list[dict], optional): Plugin specs. Each is a dict
            with a ``'cfg'`` (passed to ``build_plugin_layer``) and a
            ``'position'`` selecting where it runs: ``'before_conv1'``,
            ``'after_conv1'``, ``'after_conv2'`` or ``'after_shortcut'``.
            Specs with any other position are ignored. Default: None.
    """

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 use_conv1x1=False,
                 plugins=None):
        super().__init__()

        # NOTE: submodule creation order is kept as-is; it determines RNG
        # consumption during init and the state_dict key order.
        if use_conv1x1:
            self.conv1 = conv1x1(inplanes, planes)
            self.conv2 = conv3x3(planes, planes * self.expansion, stride)
        else:
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.conv2 = conv3x3(planes, planes * self.expansion)

        self.with_plugins = False
        if plugins:
            if isinstance(plugins, dict):
                plugins = [plugins]
            self.with_plugins = True

            def _cfgs_at(position):
                # Plugin configs registered for one insertion point.
                return [
                    plugin['cfg'] for plugin in plugins
                    if plugin['position'] == position
                ]

            self.before_conv1_plugin = _cfgs_at('before_conv1')
            self.after_conv1_plugin = _cfgs_at('after_conv1')
            self.after_conv2_plugin = _cfgs_at('after_conv2')
            self.after_shortcut_plugin = _cfgs_at('after_shortcut')

        self.planes = planes
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.bn2 = nn.BatchNorm2d(planes * self.expansion)
        self.downsample = downsample
        self.stride = stride

        if self.with_plugins:
            # conv2/bn2 emit ``planes * expansion`` channels, so plugins after
            # conv2 and after the shortcut must be built for that width.
            out_planes = planes * self.expansion
            self.before_conv1_plugin_names = self.make_block_plugins(
                inplanes, self.before_conv1_plugin)
            self.after_conv1_plugin_names = self.make_block_plugins(
                planes, self.after_conv1_plugin)
            self.after_conv2_plugin_names = self.make_block_plugins(
                out_planes, self.after_conv2_plugin)
            self.after_shortcut_plugin_names = self.make_block_plugins(
                out_planes, self.after_shortcut_plugin)

    def make_block_plugins(self, in_channels, plugins):
        """Build plugin layers and register them as submodules of the block.

        Args:
            in_channels (int): Input (and output) channels of each plugin.
            plugins (list[dict]): Plugin configs to build. An optional
                ``'postfix'`` key is forwarded to ``build_plugin_layer``; the
                caller's dict is not modified because a copy is used.

        Returns:
            list[str]: Names under which the plugins were registered, in order.
        """
        assert isinstance(plugins, list)
        plugin_names = []
        for plugin in plugins:
            plugin = plugin.copy()
            name, layer = build_plugin_layer(
                plugin,
                in_channels=in_channels,
                out_channels=in_channels,
                postfix=plugin.pop('postfix', ''))
            assert not hasattr(self, name), f'duplicate plugin {name}'
            self.add_module(name, layer)
            plugin_names.append(name)
        return plugin_names

    def forward_plugin(self, x, plugin_names):
        """Apply the named plugins to ``x`` sequentially.

        Args:
            x: Input to the first plugin.
            plugin_names (list[str]): Attribute names of the plugins, in
                application order.

        Returns:
            The output of the last plugin, or ``x`` unchanged if
            ``plugin_names`` is empty.
        """
        out = x
        for name in plugin_names:
            # Chain the plugins: each consumes the previous plugin's output.
            # Feeding ``x`` to every plugin would discard all but the last.
            out = getattr(self, name)(out)
        return out

    def forward(self, x):
        """Run the residual block.

        When plugins are configured they run at four points: on the input
        before ``conv1``, after ``conv1 -> bn1 -> relu``, after
        ``conv2 -> bn2``, and after the residual addition and final ReLU.
        """
        if self.with_plugins:
            x = self.forward_plugin(x, self.before_conv1_plugin_names)
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        if self.with_plugins:
            out = self.forward_plugin(out, self.after_conv1_plugin_names)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.with_plugins:
            out = self.forward_plugin(out, self.after_conv2_plugin_names)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        if self.with_plugins:
            out = self.forward_plugin(out, self.after_shortcut_plugin_names)

        return out
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions.

    The last 1x1 convolution widens the channels by ``expansion``. The result
    is added in place to a shortcut branch before the final ReLU.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Width of the inner convolutions; the block outputs
            ``planes * expansion`` channels.
        stride (int): Stride of the 3x3 convolution and, when enabled, of
            the projection shortcut. Default: 1.
        downsample (bool): If True, the shortcut is a strided 1x1 convolution
            followed by BatchNorm. Otherwise it is an empty ``nn.Sequential``
            that returns its input unchanged, so the input must be compatible
            with the main-branch output for the in-place addition.
            Default: False.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=False):
        super().__init__()
        wide = planes * self.expansion

        # Main branch. Creation order matches the original so RNG use during
        # init and the state_dict key order are unchanged.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, wide, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide)
        self.relu = nn.ReLU(inplace=True)

        # Shortcut branch: projection if requested, pass-through otherwise.
        shortcut_layers = []
        if downsample:
            shortcut_layers = [
                nn.Conv2d(
                    inplanes, wide, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(wide),
            ]
        self.downsample = nn.Sequential(*shortcut_layers)

    def forward(self, x):
        """Compute ``relu(main(x) + shortcut(x))``."""
        identity = self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        y += identity
        return self.relu(y)
Read the Docs v: latest
Versions
latest
stable
0.x
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.