Source code for mmocr.models.textrecog.decoders.sar_decoder

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import mmocr.utils as utils
from mmocr.models.builder import DECODERS
from .base_decoder import BaseDecoder


@DECODERS.register_module()
class ParallelSARDecoder(BaseDecoder):
    """Implementation of the Parallel Decoder module in `SAR
    <https://arxiv.org/abs/1811.00751>`_.

    Args:
        num_classes (int): Output class number :math:`C`.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_do_rnn (float): Dropout probability of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone :math:`D_i`.
        d_enc (int): Dim of encoder RNN layer :math:`D_m`.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_seq_len (int): Maximum sequence length for decoding.
        mask (bool): If True, mask padding in feature map.
        start_idx (int): Index of start token.
        padding_idx (int): Index of padding token.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
        init_cfg (dict or list[dict], optional): Initialization configs.

    Warning:
        This decoder will not predict the final class, which is assumed to
        be `<PAD>`. Therefore, its output size is always :math:`C - 1`.
        `<PAD>` is also ignored by the loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
    """

    def __init__(self,
                 num_classes=37,
                 enc_bi_rnn=False,
                 dec_bi_rnn=False,
                 dec_do_rnn=0.0,
                 dec_gru=False,
                 d_model=512,
                 d_enc=512,
                 d_k=64,
                 pred_dropout=0.0,
                 max_seq_len=40,
                 mask=True,
                 start_idx=0,
                 padding_idx=92,
                 pred_concat=False,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)

        self.num_classes = num_classes
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        self.start_idx = start_idx
        self.max_seq_len = max_seq_len
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2d(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer
        kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            batch_first=True,
            dropout=dec_do_rnn,
            bidirectional=dec_bi_rnn)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**kwargs)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_classes = num_classes - 1  # ignore padding_idx in prediction
        if pred_concat:
            fc_in_channel = decoder_rnn_out_size + d_model + \
                encoder_rnn_out_size
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, attn_size = attn_query.size()
        attn_query = attn_query.view(bsz, seq_len, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1))
        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = attn_weight.permute(0, 1, 3, 4, 2).contiguous()
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = attn_weight.size()
        assert c == 1

        if valid_ratios is not None:
            # calculate mask of attention weight
            attn_mask = torch.zeros_like(attn_weight)
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                attn_mask[i, :, :, valid_width:, :] = 1
            attn_weight = attn_weight.masked_fill(attn_mask.bool(),
                                                  float('-inf'))

        attn_weight = attn_weight.view(bsz, T, -1)
        attn_weight = F.softmax(attn_weight, dim=-1)
        attn_weight = attn_weight.view(bsz, T, h, w,
                                       c).permute(0, 1, 4, 2, 3).contiguous()

        attn_feat = torch.sum(
            torch.mul(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False)
        # bsz * (seq_len + 1) * C

        # linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.size(-1)
            holistic_feat = holistic_feat.expand(bsz, seq_len, hf_c)
            y = self.prediction(torch.cat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, targets_dict, img_metas):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape :math:`(N, D_m)`.
            targets_dict (dict): A dict with the key ``padded_targets``, a
                tensor of shape :math:`(N, T)`. Each element is the index of
                a character.
            img_metas (list[dict]): A list of dicts containing the meta
                information of input images, preferably with the key
                ``valid_ratio``.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)`.
        """
        if img_metas is not None:
            assert utils.is_type_list(img_metas, dict)
            assert len(img_metas) == feat.size(0)

        valid_ratios = None
        if img_metas is not None:
            valid_ratios = [
                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
            ] if self.mask else None

        targets = targets_dict['padded_targets'].to(feat.device)
        tgt_embedding = self.embedding(targets)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        in_dec = torch.cat((out_enc, tgt_embedding), dim=1)
        # bsz * (seq_len + 1) * emb_dim
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)
        # bsz * (seq_len + 1) * num_classes

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape :math:`(N, D_m)`.
            img_metas (list[dict]): A list of dicts containing the meta
                information of input images, preferably with the key
                ``valid_ratio``.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)`.
        """
        if img_metas is not None:
            assert utils.is_type_list(img_metas, dict)
            assert len(img_metas) == feat.size(0)

        valid_ratios = None
        if img_metas is not None:
            valid_ratios = [
                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
            ] if self.mask else None

        seq_len = self.max_seq_len
        bsz = feat.size(0)

        start_token = torch.full((bsz, ),
                                 self.start_idx,
                                 device=feat.device,
                                 dtype=torch.long)
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = torch.cat((out_enc, start_token), dim=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
        for i in range(1, seq_len + 1):
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            _, max_idx = torch.max(char_output, dim=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                decoder_input[:, i + 1, :] = char_embedding

        outputs = torch.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs

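The parallel decoder scores every time step in one batched RNN pass: training concatenates the holistic feature with the embedded target sequence (teacher forcing), while testing repeatedly re-runs the pass and writes each greedy prediction into the next input slot. Below is a minimal smoke-test sketch, not part of the upstream file; the batch size, the 8 x 25 feature map, and the 93-class dictionary with start index 91 and padding index 92 are illustrative assumptions (matching the ``padding_idx=92`` default above), and the call relies on ``BaseDecoder.forward`` dispatching to ``forward_test`` when ``train_mode=False``.

# Hedged usage sketch; all concrete sizes below are assumptions.
decoder = ParallelSARDecoder(num_classes=93, start_idx=91, padding_idx=92)
decoder.eval()  # disable dropout modules for inference

feat = torch.rand(2, 512, 8, 25)  # (N, D_i, H, W) backbone feature map
out_enc = torch.rand(2, 512)      # (N, D_m) holistic feature from the SAR encoder
img_metas = [dict(valid_ratio=1.0), dict(valid_ratio=0.6)]  # mask padded width

with torch.no_grad():
    probs = decoder(feat, out_enc, img_metas=img_metas, train_mode=False)
print(probs.shape)  # torch.Size([2, 40, 92]) -> (N, max_seq_len, C - 1)

The second image's ``valid_ratio=0.6`` causes the rightmost 40% of each attention map to be filled with ``-inf`` before the softmax, so padded columns receive zero attention weight.
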
@DECODERS.register_module()
class SequentialSARDecoder(BaseDecoder):
    """Implementation of the Sequential Decoder module in `SAR
    <https://arxiv.org/abs/1811.00751>`_.

    Args:
        num_classes (int): Output class number :math:`C`.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_k (int): Dim of conv layers in attention module.
        d_model (int): Dim of channels from backbone :math:`D_i`.
        d_enc (int): Dim of encoder RNN layer :math:`D_m`.
        pred_dropout (float): Dropout probability of prediction layer.
        max_seq_len (int): Maximum sequence length during decoding.
        mask (bool): If True, mask padding in feature map.
        start_idx (int): Index of start token.
        padding_idx (int): Index of padding token.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 num_classes=37,
                 enc_bi_rnn=False,
                 dec_bi_rnn=False,
                 dec_gru=False,
                 d_k=64,
                 d_model=512,
                 d_enc=512,
                 pred_dropout=0.0,
                 mask=True,
                 max_seq_len=40,
                 start_idx=0,
                 padding_idx=92,
                 pred_concat=False,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)

        self.num_classes = num_classes
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        self.start_idx = start_idx
        self.dec_gru = dec_gru
        self.max_seq_len = max_seq_len
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer
        self.conv1x1_1 = nn.Conv2d(
            decoder_rnn_out_size, d_k, kernel_size=1, stride=1)
        self.conv3x3_1 = nn.Conv2d(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Conv2d(d_k, 1, kernel_size=1, stride=1)

        # Decoder RNN layers
        if dec_gru:
            self.rnn_decoder_layer1 = nn.GRUCell(encoder_rnn_out_size,
                                                 encoder_rnn_out_size)
            self.rnn_decoder_layer2 = nn.GRUCell(encoder_rnn_out_size,
                                                 encoder_rnn_out_size)
        else:
            self.rnn_decoder_layer1 = nn.LSTMCell(encoder_rnn_out_size,
                                                  encoder_rnn_out_size)
            self.rnn_decoder_layer2 = nn.LSTMCell(encoder_rnn_out_size,
                                                  encoder_rnn_out_size)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_class = num_classes - 1  # ignore padding index
        if pred_concat:
            fc_in_channel = decoder_rnn_out_size + d_model + d_enc
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_class)

    def _2d_attention(self,
                      y_prev,
                      feat,
                      holistic_feat,
                      hx1,
                      cx1,
                      hx2,
                      cx2,
                      valid_ratios=None):
        _, _, h_feat, w_feat = feat.size()
        if self.dec_gru:
            hx1 = cx1 = self.rnn_decoder_layer1(y_prev, hx1)
            hx2 = cx2 = self.rnn_decoder_layer2(hx1, hx2)
        else:
            hx1, cx1 = self.rnn_decoder_layer1(y_prev, (hx1, cx1))
            hx2, cx2 = self.rnn_decoder_layer2(hx1, (hx2, cx2))

        tile_hx2 = hx2.view(hx2.size(0), hx2.size(1), 1, 1)
        attn_query = self.conv1x1_1(tile_hx2)  # bsz * attn_size * 1 * 1
        attn_query = attn_query.expand(-1, -1, h_feat, w_feat)
        attn_key = self.conv3x3_1(feat)
        attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1))
        attn_weight = self.conv1x1_2(attn_weight)
        bsz, c, h, w = attn_weight.size()
        assert c == 1

        if valid_ratios is not None:
            # calculate mask of attention weight
            attn_mask = torch.zeros_like(attn_weight)
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                attn_mask[i, :, :, valid_width:] = 1
            attn_weight = attn_weight.masked_fill(attn_mask.bool(),
                                                  float('-inf'))

        attn_weight = F.softmax(attn_weight.view(bsz, -1), dim=-1)
        attn_weight = attn_weight.view(bsz, c, h, w)

        attn_feat = torch.sum(
            torch.mul(feat, attn_weight), (2, 3), keepdim=False)  # n * c

        # linear transformation
        if self.pred_concat:
            y = self.prediction(torch.cat((hx2, attn_feat, holistic_feat), 1))
        else:
            y = self.prediction(attn_feat)

        # Bug fix: the original returned ``y, hx1, hx1, hx2, hx2``, which
        # silently dropped the LSTM cell states; return both hidden and
        # cell states so the next step resumes from the correct state.
        return y, hx1, cx1, hx2, cx2

    def forward_train(self, feat, out_enc, targets_dict, img_metas=None):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape :math:`(N, D_m)`.
            targets_dict (dict): A dict with the key ``padded_targets``, a
                tensor of shape :math:`(N, T)`. Each element is the index of
                a character.
            img_metas (list[dict]): A list of dicts containing the meta
                information of input images, preferably with the key
                ``valid_ratio``.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)`.
        """
        if img_metas is not None:
            assert utils.is_type_list(img_metas, dict)
            assert len(img_metas) == feat.size(0)

        valid_ratios = None
        if img_metas is not None:
            valid_ratios = [
                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
            ] if self.mask else None

        if self.train_mode:
            targets = targets_dict['padded_targets'].to(feat.device)
            tgt_embedding = self.embedding(targets)

        outputs = []
        start_token = torch.full((feat.size(0), ),
                                 self.start_idx,
                                 device=feat.device,
                                 dtype=torch.long)
        start_token = self.embedding(start_token)
        for i in range(-1, self.max_seq_len):
            if i == -1:
                # Step -1 primes the hidden states with the holistic feature.
                if self.dec_gru:
                    hx1 = cx1 = self.rnn_decoder_layer1(out_enc)
                    hx2 = cx2 = self.rnn_decoder_layer2(hx1)
                else:
                    hx1, cx1 = self.rnn_decoder_layer1(out_enc)
                    hx2, cx2 = self.rnn_decoder_layer2(hx1)
                if not self.train_mode:
                    y_prev = start_token
            else:
                if self.train_mode:
                    y_prev = tgt_embedding[:, i, :]
                y, hx1, cx1, hx2, cx2 = self._2d_attention(
                    y_prev,
                    feat,
                    out_enc,
                    hx1,
                    cx1,
                    hx2,
                    cx2,
                    valid_ratios=valid_ratios)
                if self.train_mode:
                    y = self.pred_dropout(y)
                else:
                    y = F.softmax(y, -1)
                    _, max_idx = torch.max(y, dim=1, keepdim=False)
                    char_embedding = self.embedding(max_idx)
                    y_prev = char_embedding
                outputs.append(y)

        outputs = torch.stack(outputs, 1)

        return outputs

    def forward_test(self, feat, out_enc, img_metas):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape :math:`(N, D_m)`.
            img_metas (list[dict]): A list of dicts containing the meta
                information of input images, preferably with the key
                ``valid_ratio``.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)`.
        """
        if img_metas is not None:
            assert utils.is_type_list(img_metas, dict)
            assert len(img_metas) == feat.size(0)

        # Greedy decoding reuses the training loop; ``self.train_mode`` is
        # False here, so the loop feeds back its own predictions.
        return self.forward_train(feat, out_enc, None, img_metas)
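
Both decoders register themselves in ``DECODERS``, so configs normally build them from a dict rather than importing the classes directly. A minimal sketch of that path, assuming mmcv's ``build_from_cfg`` helper and the same illustrative 93-class dictionary as above (start index 91, padding index 92):

# Hedged registry sketch; the class indices are assumptions, not fixed API.
from mmcv.utils import build_from_cfg

cfg = dict(
    type='SequentialSARDecoder',
    num_classes=93,
    start_idx=91,
    padding_idx=92,
    dec_gru=False,      # a pair of LSTMCells; True swaps in GRUCells
    pred_concat=True)   # fc input: hidden state + glimpse + holistic feature
decoder = build_from_cfg(cfg, DECODERS)

The sequential variant unrolls one RNN cell step per character, which mirrors the paper's step-by-step formulation; the parallel variant above trades that for a single batched RNN pass over the whole target sequence during training.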