Source code for mmocr.models.textrecog.encoders.sar_encoder

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import mmocr.utils as utils
from mmocr.models.builder import ENCODERS
from .base_encoder import BaseEncoder


@ENCODERS.register_module()
class SAREncoder(BaseEncoder):
    """Implementation of encoder module in `SAR
    <https://arxiv.org/abs/1811.00751>`_.

    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_do_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim :math:`D_i` of channels from backbone.
        d_enc (int): Dim :math:`D_m` of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_do_rnn=0.0,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 init_cfg=[
                     dict(type='Xavier', layer='Conv2d'),
                     dict(type='Uniform', layer='BatchNorm2d')
                 ],
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_do_rnn, (int, float))
        assert 0 <= enc_do_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_do_rnn = enc_do_rnn
        self.mask = mask

        # LSTM (or GRU) encoder over the width dimension
        kwargs = dict(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            batch_first=True,
            dropout=enc_do_rnn,
            bidirectional=enc_bi_rnn)
        if enc_gru:
            self.rnn_encoder = nn.GRU(**kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**kwargs)

        # Global feature transformation
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def forward(self, feat, img_metas=None):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            img_metas (list[dict]): A list of dicts containing meta
                information of the input images. Preferably with the key
                ``valid_ratio``.

        Returns:
            Tensor: A tensor of shape :math:`(N, D_m)`.
        """
        if img_metas is not None:
            assert utils.is_type_list(img_metas, dict)
            assert len(img_metas) == feat.size(0)

        valid_ratios = None
        if img_metas is not None:
            valid_ratios = [
                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
            ] if self.mask else None

        # Collapse the height dimension with a vertical max-pool
        h_feat = feat.size(2)
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = feat_v.permute(0, 2, 1).contiguous()  # bsz * W * C

        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C

        # Take the hidden state at the last valid (non-padded) time step
        if valid_ratios is not None:
            valid_hf = []
            T = holistic_feat.size(1)
            for i, valid_ratio in enumerate(valid_ratios):
                valid_step = min(T, math.ceil(T * valid_ratio)) - 1
                valid_hf.append(holistic_feat[i, valid_step, :])
            valid_hf = torch.stack(valid_hf, dim=0)
        else:
            valid_hf = holistic_feat[:, -1, :]  # bsz * C

        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat
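
For illustration, below is a minimal usage sketch, not part of the module itself. It assumes an MMOCR 0.x install with PyTorch available; the tensor shapes and ``valid_ratio`` values are made up for the example. It builds the encoder directly and runs it on a dummy backbone feature map:

import torch
from mmocr.models.textrecog.encoders import SAREncoder

# Dummy backbone output: batch of 2 feature maps with D_i = 512 channels,
# height 6 and width 40 (illustrative values only).
feat = torch.randn(2, 512, 6, 40)

# Per-image meta info; valid_ratio is the fraction of the width occupied
# by actual image content, the rest being padding.
img_metas = [dict(valid_ratio=1.0), dict(valid_ratio=0.6)]

encoder = SAREncoder(enc_bi_rnn=False, enc_gru=False, d_model=512, d_enc=512)
encoder.eval()

with torch.no_grad():
    holistic_feat = encoder(feat, img_metas)

print(holistic_feat.shape)  # torch.Size([2, 512]), i.e. (N, D_m)

In a full recognizer config, the encoder would normally be built through the registry instead, e.g. ``encoder=dict(type='SAREncoder', enc_bi_rnn=False)``, which is what the ``@ENCODERS.register_module()`` decorator above enables.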