# Source code for mmocr.models.textdet.necks.fpn_cat

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import auto_fp16

from mmdet.models.builder import NECKS


@NECKS.register_module()
class FPNC(nn.Module):
    """FPN-like fusion module in Real-time Scene Text Detection with
    Differentiable Binarization.

    This was partially adapted from https://github.com/MhLiao/DB and
    https://github.com/WenmuZhou/DBNet.pytorch

    Each input level is projected by a 1x1 lateral conv, fused top-down by
    nearest-neighbour upsampling and addition, smoothed by a 3x3 conv,
    resized to the spatial size of the first (finest) level and concatenated
    along the channel dimension.

    Args:
        in_channels (list[int]): Number of channels of each input level.
        lateral_channels (int): Output channels of every lateral (1x1) conv.
        out_channels (int): Output channels of every smooth (3x3) conv. The
            concatenated output has ``out_channels * len(in_channels)``
            channels.
        bias_on_lateral (bool): Whether the lateral convs use a bias.
        bn_re_on_lateral (bool): Whether the lateral convs are followed by
            BN and ReLU.
        bias_on_smooth (bool): Whether the smooth convs use a bias.
        bn_re_on_smooth (bool): Whether the smooth convs are followed by BN
            and ReLU.
        conv_after_concat (bool): Whether to apply an extra 3x3 conv + BN +
            ReLU to the concatenated features.
    """

    def __init__(self,
                 in_channels,
                 lateral_channels=256,
                 out_channels=64,
                 bias_on_lateral=False,
                 bn_re_on_lateral=False,
                 bias_on_smooth=False,
                 bn_re_on_smooth=False,
                 conv_after_concat=False):
        super(FPNC, self).__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.lateral_channels = lateral_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.bn_re_on_lateral = bn_re_on_lateral
        self.bn_re_on_smooth = bn_re_on_smooth
        self.conv_after_concat = conv_after_concat
        self.lateral_convs = nn.ModuleList()
        self.smooth_convs = nn.ModuleList()
        self.num_outs = self.num_ins

        for i in range(self.num_ins):
            lat_norm_cfg, lat_act_cfg = self._bn_relu_cfgs(
                self.bn_re_on_lateral)
            l_conv = ConvModule(
                in_channels[i],
                lateral_channels,
                1,
                bias=bias_on_lateral,
                conv_cfg=None,
                norm_cfg=lat_norm_cfg,
                act_cfg=lat_act_cfg,
                inplace=False)

            smooth_norm_cfg, smooth_act_cfg = self._bn_relu_cfgs(
                self.bn_re_on_smooth)
            smooth_conv = ConvModule(
                lateral_channels,
                out_channels,
                3,
                bias=bias_on_smooth,
                padding=1,
                conv_cfg=None,
                norm_cfg=smooth_norm_cfg,
                act_cfg=smooth_act_cfg,
                inplace=False)

            self.lateral_convs.append(l_conv)
            self.smooth_convs.append(smooth_conv)

        if self.conv_after_concat:
            # The post-concat conv is always followed by BN + ReLU.
            norm_cfg, act_cfg = self._bn_relu_cfgs(True)
            self.out_conv = ConvModule(
                out_channels * self.num_outs,
                out_channels * self.num_outs,
                3,
                padding=1,
                conv_cfg=None,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)

    @staticmethod
    def _bn_relu_cfgs(enabled):
        """Return fresh ``(norm_cfg, act_cfg)`` for BN + ReLU when
        ``enabled``, otherwise ``(None, None)``."""
        if enabled:
            return dict(type='BN'), dict(type='ReLU')
        return None, None

    # default init_weights for conv(msra) and norm in ConvModule
    def init_weights(self):
        """Initialize the weights of FPN module.

        Delegates to ``ConvModule.init_weights`` for every lateral conv,
        every smooth conv and, if present, the post-concat conv.
        """
        for m in self.lateral_convs:
            m.init_weights()
        for m in self.smooth_convs:
            m.init_weights()
        if self.conv_after_concat:
            self.out_conv.init_weights()

    @auto_fp16()
    def forward(self, inputs):
        """Fuse multi-level features into a single feature map.

        Args:
            inputs (list[Tensor] | tuple[Tensor]): One feature map per input
                level, ordered from the finest (index 0) to the coarsest
                resolution; presumably shaped (N, C_i, H_i, W_i) — TODO
                confirm against the backbone.

        Returns:
            Tensor: The smoothed levels resized to the spatial size of level
            0 and concatenated on dim 1 (``out_channels * num_ins``
            channels), passed through ``out_conv`` if ``conv_after_concat``.
        """
        assert len(inputs) == len(self.in_channels)
        # build laterals
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        used_backbone_levels = len(laterals)

        # build top-down path
        # Out-of-place add on purpose: with bn_re_on_lateral the laterals are
        # ReLU outputs, which ReLU's backward reads; an in-place ``+=`` would
        # modify them and make autograd raise during backward().
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i], size=prev_shape, mode='nearest')

        # build outputs
        # part 1: from original levels
        outs = [
            self.smooth_convs[i](laterals[i])
            for i in range(used_backbone_levels)
        ]

        # Resize every level to the finest level's size explicitly instead
        # of scale_factor=2**i: odd spatial sizes do not halve exactly, and
        # a fixed factor would then yield mismatched shapes that torch.cat
        # rejects. For exact 2x pyramids the result is the same.
        target_shape = outs[0].shape[2:]
        outs = [outs[0]] + [
            F.interpolate(level_out, size=target_shape, mode='nearest')
            for level_out in outs[1:]
        ]

        out = torch.cat(outs, dim=1)

        if self.conv_after_concat:
            out = self.out_conv(out)

        return out