Note
You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend you upgrade to MMOCR 1.0 to enjoy the fruitful new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code, and documentation of MMOCR 1.0 for more details.
Source code for mmocr.models.textrecog.encoders.sar_encoder
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import mmocr.utils as utils
from mmocr.models.builder import ENCODERS
from .base_encoder import BaseEncoder
@ENCODERS.register_module()
class SAREncoder(BaseEncoder):
    """Encoder of the `SAR <https://arxiv.org/abs/1811.00751>`_ text
    recognizer.

    Max-pools the backbone feature map over its height, feeds the resulting
    width-wise sequence through a 2-layer RNN, and projects the last valid
    hidden state into a holistic feature vector.

    Args:
        enc_bi_rnn (bool): If True, use a bidirectional RNN in the encoder.
        enc_do_rnn (float): Dropout probability of the RNN layer in the
            encoder. Must lie in ``[0, 1)``.
        enc_gru (bool): If True, use GRU; otherwise use LSTM.
        d_model (int): Dim :math:`D_i` of channels from backbone.
        d_enc (int): Dim :math:`D_m` of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence using the
            ``valid_ratio`` entries of ``img_metas``.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_do_rnn=0.0,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 init_cfg=[
                     dict(type='Xavier', layer='Conv2d'),
                     dict(type='Uniform', layer='BatchNorm2d')
                 ],
                 **kwargs):
        super().__init__(init_cfg=init_cfg)

        # Validate configuration before building any submodule.
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_do_rnn, (int, float))
        assert 0 <= enc_do_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_do_rnn = enc_do_rnn
        self.mask = mask

        # Recurrent encoder: 2 layers, batch-first, optional bidirection.
        rnn_cfg = dict(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            batch_first=True,
            dropout=enc_do_rnn,
            bidirectional=enc_bi_rnn)
        self.rnn_encoder = (nn.GRU(**rnn_cfg)
                            if enc_gru else nn.LSTM(**rnn_cfg))

        # Global feature transformation; output width doubles when the RNN
        # is bidirectional.
        rnn_out_size = d_enc * (2 if enc_bi_rnn else 1)
        self.linear = nn.Linear(rnn_out_size, rnn_out_size)

    def forward(self, feat, img_metas=None):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            img_metas (dict): A dict that contains meta information of input
                images. Preferably with the key ``valid_ratio``.

        Returns:
            Tensor: A tensor of shape :math:`(N, D_m)`.
        """
        if img_metas is not None:
            assert utils.is_type_list(img_metas, dict)
            assert len(img_metas) == feat.size(0)

        # Per-image fraction of the width that carries real (unpadded)
        # content; only used when masking is enabled.
        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = [
                meta.get('valid_ratio', 1.0) for meta in img_metas
            ]

        # Collapse the height axis with max pooling, then treat width as
        # the sequence dimension: (N, C, H, W) -> (N, W, C).
        height = feat.size(2)
        pooled = F.max_pool2d(
            feat, kernel_size=(height, 1), stride=1, padding=0)
        pooled = pooled.squeeze(2)
        pooled = pooled.permute(0, 2, 1).contiguous()

        seq_feat = self.rnn_encoder(pooled)[0]  # (N, T, C)

        if valid_ratios is None:
            # No masking: the holistic feature is the final time step.
            last_steps = seq_feat[:, -1, :]  # (N, C)
        else:
            # With masking: pick the last time step inside the valid region
            # of each image.
            T = seq_feat.size(1)
            last_steps = torch.stack([
                seq_feat[i, min(T, math.ceil(T * ratio)) - 1, :]
                for i, ratio in enumerate(valid_ratios)
            ], dim=0)

        return self.linear(last_steps)  # (N, C)