Shortcuts

Source code for mmocr.models.textrecog.encoders.transformer

# Copyright (c) OpenMMLab. All rights reserved.
import copy

from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.runner import BaseModule, ModuleList

from mmocr.models.builder import ENCODERS
from mmocr.models.common.modules import PositionalEncoding


@ENCODERS.register_module()
class TransformerEncoder(BaseModule):
    """Implement transformer encoder for text recognition, modified from
    `<https://github.com/FangShancheng/ABINet>`.

    The encoder flattens a 2-D feature map into a sequence, adds positional
    encoding, runs it through a stack of identical self-attention layers and
    folds the sequence back into the original 2-D layout.

    Args:
        n_layers (int): Number of attention layers.
        n_head (int): Number of parallel attention heads.
        d_model (int): Dimension :math:`D_m` of the input from previous model.
            Must be divisible by ``n_head``.
        d_inner (int): Hidden dimension of feedforward layers.
        dropout (float): Dropout rate. The same value is used for the
            attention dropout, the attention output dropout layer and the
            feedforward dropout.
        max_len (int): Maximum output sequence length :math:`T`. It is passed
            to the positional encoding as its number of positions.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 n_layers=2,
                 n_head=8,
                 d_model=512,
                 d_inner=2048,
                 dropout=0.1,
                 max_len=8 * 32,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        assert d_model % n_head == 0, 'd_model must be divisible by n_head'

        # NOTE(review): presumably H * W of the input to ``forward`` must not
        # exceed ``max_len`` -- confirm against ``PositionalEncoding``.
        self.pos_encoder = PositionalEncoding(d_model, n_position=max_len)

        # Template layer: self-attention -> norm -> feedforward -> norm.
        encoder_layer = BaseTransformerLayer(
            operation_order=('self_attn', 'norm', 'ffn', 'norm'),
            attn_cfgs=dict(
                type='MultiheadAttention',
                embed_dims=d_model,
                num_heads=n_head,
                attn_drop=dropout,
                dropout_layer=dict(type='Dropout', drop_prob=dropout),
            ),
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=d_model,
                feedforward_channels=d_inner,
                ffn_drop=dropout,
            ),
            norm_cfg=dict(type='LN'),
        )
        # Deep-copy the template so that every layer owns its own parameters.
        self.transformer = ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(n_layers)])

    def forward(self, feature):
        """Encode a 2-D feature map with the transformer layers.

        Args:
            feature (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.

        Returns:
            Tensor: Features of shape :math:`(N, D_m, H, W)`.
        """
        n, c, h, w = feature.shape
        # ``reshape`` instead of ``view``: it yields the same view for a
        # contiguous input but also accepts non-contiguous feature maps,
        # where ``view`` would raise.
        feature = feature.reshape(n, c, h * w).transpose(1, 2)  # (n, h*w, c)
        feature = self.pos_encoder(feature)  # (n, h*w, c)
        feature = feature.transpose(0, 1)  # (h*w, n, c)
        for m in self.transformer:
            feature = m(feature)
        # Back to the (n, c, h, w) layout of the input.
        feature = feature.permute(1, 2, 0).reshape(n, c, h, w)
        return feature
Read the Docs v: v0.4.0
Versions
latest
stable
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.