
Source code for mmocr.models.textrecog.decoders.sequence_attention_decoder

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmocr.models.builder import DECODERS
from mmocr.models.textrecog.layers import DotProductAttentionLayer
from .base_decoder import BaseDecoder


@DECODERS.register_module()
class SequenceAttentionDecoder(BaseDecoder):
    """Sequence attention decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        rnn_layers (int): Number of RNN layers.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also
            match the channel dimension of the encoder output ``out_enc``.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        padding_idx (int): The index of `<PAD>`.
        dropout (float): Dropout rate.
        return_feature (bool): Return feature or logits as the result.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.
        init_cfg (dict or list[dict], optional): Initialization configs.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
    """

    def __init__(self,
                 num_classes=None,
                 rnn_layers=2,
                 dim_input=512,
                 dim_model=128,
                 max_seq_len=40,
                 start_idx=0,
                 mask=True,
                 padding_idx=None,
                 dropout=0,
                 return_feature=False,
                 encode_value=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.return_feature = return_feature
        self.encode_value = encode_value
        self.max_seq_len = max_seq_len
        self.start_idx = start_idx
        self.mask = mask

        self.embedding = nn.Embedding(
            self.num_classes, self.dim_model, padding_idx=padding_idx)

        self.sequence_layer = nn.LSTM(
            input_size=dim_model,
            hidden_size=dim_model,
            num_layers=rnn_layers,
            batch_first=True,
            dropout=dropout)

        self.attention_layer = DotProductAttentionLayer()

        self.prediction = None
        if not self.return_feature:
            pred_num_classes = num_classes - 1
            self.prediction = nn.Linear(
                dim_model if encode_value else dim_input, pred_num_classes)

    def forward_train(self, feat, out_enc, targets_dict, img_metas):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            targets_dict (dict): A dict with the key ``padded_targets``, a
                tensor of shape :math:`(N, T)`. Each element is the index of
                a character.
            img_metas (list[dict]): A list of dicts that contain meta
                information of the input images. Preferably with the key
                ``valid_ratio``.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it would be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        valid_ratios = [
            img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
        ] if self.mask else None

        targets = targets_dict['padded_targets'].to(feat.device)
        tgt_embedding = self.embedding(targets)

        n, c_enc, h, w = out_enc.size()
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.size()
        assert c_feat == self.dim_input
        _, len_q, c_q = tgt_embedding.size()
        assert c_q == self.dim_model
        assert len_q <= self.max_seq_len

        query, _ = self.sequence_layer(tgt_embedding)
        query = query.permute(0, 2, 1).contiguous()
        key = out_enc.view(n, c_enc, h * w)
        if self.encode_value:
            value = key
        else:
            value = feat.view(n, c_feat, h * w)

        # Mask the zero-padded right-hand region of each image, i.e. the
        # columns beyond ceil(w * valid_ratio).
        mask = None
        if valid_ratios is not None:
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        attn_out = self.attention_layer(query, key, value, mask)
        attn_out = attn_out.permute(0, 2, 1).contiguous()

        if self.return_feature:
            return attn_out

        out = self.prediction(attn_out)

        return out

    def forward_test(self, feat, out_enc, img_metas):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            img_metas (list[dict]): A list of dicts that contain meta
                information of the input images. Preferably with the key
                ``valid_ratio``.

        Returns:
            Tensor: The output sequence tensor of shape :math:`(N, T, C-1)`,
            containing character probabilities, since ``forward_test_step``
            applies softmax to its logits.
        """
        seq_len = self.max_seq_len
        batch_size = feat.size(0)

        decode_sequence = (feat.new_ones(
            (batch_size, seq_len)) * self.start_idx).long()

        # Greedy decoding: feed each step's argmax back as the next input
        # token in the decode history.
        outputs = []
        for i in range(seq_len):
            step_out = self.forward_test_step(feat, out_enc, decode_sequence,
                                              i, img_metas)
            outputs.append(step_out)
            _, max_idx = torch.max(step_out, dim=1, keepdim=False)
            if i < seq_len - 1:
                decode_sequence[:, i + 1] = max_idx

        outputs = torch.stack(outputs, 1)

        return outputs

    def forward_test_step(self, feat, out_enc, decode_sequence, current_step,
                          img_metas):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            decode_sequence (Tensor): Shape :math:`(N, T)`. The tensor that
                stores the history decoding result.
            current_step (int): Current decoding step.
            img_metas (list[dict]): A list of dicts that contain meta
                information of the input images. Preferably with the key
                ``valid_ratio``.

        Returns:
            Tensor: Shape :math:`(N, C-1)`. Character probabilities at the
            current time step, i.e. the softmax of the prediction logits.
        """
        valid_ratios = [
            img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
        ] if self.mask else None

        embed = self.embedding(decode_sequence)

        n, c_enc, h, w = out_enc.size()
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.size()
        assert c_feat == self.dim_input
        _, _, c_q = embed.size()
        assert c_q == self.dim_model

        query, _ = self.sequence_layer(embed)
        query = query.permute(0, 2, 1).contiguous()
        key = out_enc.view(n, c_enc, h * w)
        if self.encode_value:
            value = key
        else:
            value = feat.view(n, c_feat, h * w)

        # Mask the zero-padded right-hand region of each image, i.e. the
        # columns beyond ceil(w * valid_ratio).
        mask = None
        if valid_ratios is not None:
            mask = query.new_zeros((n, h, w))
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                mask[i, :, valid_width:] = 1
            mask = mask.bool()
            mask = mask.view(n, h * w)

        # [n, c, l]
        attn_out = self.attention_layer(query, key, value, mask)

        out = attn_out[:, :, current_step]

        if self.return_feature:
            return out

        out = self.prediction(out)
        out = F.softmax(out, dim=-1)

        return out
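
The decoder exposes two entry points: ``forward_train`` runs teacher forcing over the padded targets in a single pass, while ``forward_test`` decodes greedily, calling ``forward_test_step`` once per step and writing each argmax back into the decode history. Below is a minimal usage sketch, run in the context of this module; the vocabulary size, tensor shapes, and index choices are illustrative assumptions, not values taken from the MMOCR configs.

import torch

# Assumed setup: 36 characters plus <PAD>, hence num_classes=37 and
# padding_idx=36; all shapes below are stand-ins.
decoder = SequenceAttentionDecoder(
    num_classes=37,
    dim_input=512,   # must equal the channel dim D_i of ``feat``
    dim_model=128,   # must equal the channel dim D_m of ``out_enc``
    max_seq_len=40,
    padding_idx=36)

n, h, w = 2, 8, 32
feat = torch.rand(n, 512, h, w)     # stand-in backbone feature, (N, D_i, H, W)
out_enc = torch.rand(n, 128, h, w)  # stand-in encoder output, (N, D_m, H, W)
img_metas = [dict(valid_ratio=1.0) for _ in range(n)]

# Training: teacher forcing on padded character indices of shape (N, T).
targets_dict = dict(padded_targets=torch.randint(0, 36, (n, 40)))
logits = decoder.forward_train(feat, out_enc, targets_dict, img_metas)
assert logits.shape == (n, 40, 36)  # (N, T, C - 1), raw logits

# Inference: greedy decoding starting from <SOS> (start_idx, 0 by default).
probs = decoder.forward_test(feat, out_enc, img_metas)
assert probs.shape == (n, 40, 36)   # (N, T, C - 1), softmax probabilities

Because ``forward_test_step`` applies softmax before returning, ``forward_test`` yields probabilities rather than raw logits; the greedy ``torch.max`` over dimension 1 of each :math:`(N, C-1)` step output selects the next character.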