Shortcuts

Note

You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.

Source code for mmocr.models.textrecog.decoders.nrtr_decoder

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import ModuleList

from mmocr.models.builder import DECODERS
from mmocr.models.common import PositionalEncoding, TFDecoderLayer
from .base_decoder import BaseDecoder


@DECODERS.register_module()
class NRTRDecoder(BaseDecoder):
    """Transformer Decoder block with self attention mechanism.

    Args:
        n_layers (int): Number of attention layers.
        d_embedding (int): Language embedding dimension.
        n_head (int): Number of parallel attention heads.
        d_k (int): Dimension of the key vector.
        d_v (int): Dimension of the value vector.
        d_model (int): Dimension :math:`D_m` of the input from previous model.
        d_inner (int): Hidden dimension of feedforward layers.
        n_position (int): Length of the positional encoding vector. Must be
            greater than ``max_seq_len``.
        dropout (float): Dropout rate.
        num_classes (int): Number of output classes :math:`C`.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        padding_idx (int): The index of `<PAD>`.
        init_cfg (dict or list[dict], optional): Initialization configs.
        **kwargs: Extra keyword arguments forwarded to every
            :obj:`TFDecoderLayer`.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.

    Note:
        The embedded target sequence (width ``d_embedding``) is fed straight
        into layers built with ``d_model``, so the two are presumably expected
        to be equal -- confirm against ``TFDecoderLayer`` before changing one.
    """

    def __init__(self,
                 n_layers=6,
                 d_embedding=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=256,
                 n_position=200,
                 dropout=0.1,
                 num_classes=93,
                 max_seq_len=40,
                 start_idx=1,
                 padding_idx=92,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)

        self.padding_idx = padding_idx
        self.start_idx = start_idx
        self.max_seq_len = max_seq_len

        # Submodules are registered in the same order as before so that
        # state_dict keys and parameter initialisation order are unchanged.
        self.trg_word_emb = nn.Embedding(
            num_classes, d_embedding, padding_idx=padding_idx)

        self.position_enc = PositionalEncoding(
            d_embedding, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)

        self.layer_stack = ModuleList([
            TFDecoderLayer(
                d_model, d_inner, n_head, d_k, d_v, dropout=dropout, **kwargs)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        # The last class (assumed to be <PAD>) is never predicted.
        pred_num_class = num_classes - 1
        self.classifier = nn.Linear(d_model, pred_num_class)

    @staticmethod
    def get_pad_mask(seq, pad_idx):
        """Build a mask that is ``True`` wherever ``seq`` is not padding.

        Args:
            seq (Tensor): Index tensor, e.g. of shape :math:`(N, T)`.
            pad_idx (int): Index treated as padding.

        Returns:
            Tensor: Bool tensor with a singleton dimension inserted before the
            last one, i.e. :math:`(N, 1, T)` for a 2-D ``seq``.
        """
        return (seq != pad_idx).unsqueeze(-2)

    @staticmethod
    def get_subsequent_mask(seq):
        """For masking out the subsequent info.

        Args:
            seq (Tensor): Tensor whose second dimension is the sequence
                length :math:`T`.

        Returns:
            Tensor: Bool tensor of shape :math:`(1, T, T)` that is ``True``
            on and below the diagonal, so position ``i`` may only attend to
            positions ``<= i``.
        """
        len_s = seq.size(1)
        # Lower triangle incl. diagonal; same as 1 - triu(ones, diagonal=1).
        causal = torch.tril(torch.ones((len_s, len_s), device=seq.device))
        return causal.bool().unsqueeze(0)

    def _attention(self, trg_seq, src, src_mask=None):
        """Run the decoder layer stack over a target index sequence.

        Args:
            trg_seq (Tensor): Target indices of shape :math:`(N, T)`.
            src (Tensor): Encoder output, passed to every layer as the
                memory for decoder-encoder attention.
            src_mask (Tensor, optional): Mask over ``src`` positions as
                built by :meth:`_get_mask`, or ``None``.

        Returns:
            Tensor: Layer-normalised decoder features whose last dimension
            is ``d_model``.
        """
        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        tgt = self.dropout(trg_pos_encoded)

        # Padding mask (N, 1, T) & causal mask (1, T, T) broadcast to an
        # (N, T, T) self-attention mask.
        trg_mask = self.get_pad_mask(
            trg_seq,
            pad_idx=self.padding_idx) & self.get_subsequent_mask(trg_seq)

        output = tgt
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
                src,
                self_attn_mask=trg_mask,
                dec_enc_attn_mask=src_mask)
        return self.layer_norm(output)

    def _get_mask(self, logit, img_metas):
        """Build a mask over encoder positions from each image's valid ratio.

        Args:
            logit (Tensor): Encoder output of shape :math:`(N, T, D_m)`; only
                its size, dtype and device are used.
            img_metas (list[dict] or None): Per-image meta information. A
                missing ``valid_ratio`` key defaults to ``1.0``.

        Returns:
            Tensor or None: ``None`` when ``img_metas`` is ``None``; otherwise
            a tensor of shape :math:`(N, T)` with the dtype/device of
            ``logit``, holding 1 for the first ``ceil(T * valid_ratio)``
            positions of each row (capped at ``T``) and 0 elsewhere.
        """
        N, T, _ = logit.size()
        if img_metas is None:
            return None

        mask = logit.new_zeros((N, T))
        for i, img_meta in enumerate(img_metas):
            valid_ratio = img_meta.get('valid_ratio', 1.0)
            valid_width = min(T, math.ceil(T * valid_ratio))
            mask[i, :valid_width] = 1
        return mask

    def forward_train(self, feat, out_enc, targets_dict, img_metas):
        r"""
        Args:
            feat (None): Unused.
            out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)`
                where :math:`D_m` is ``d_model``.
            targets_dict (dict): A dict with the key ``padded_targets``, a
                tensor of shape :math:`(N, T)`. Each element is the index of a
                character.
            img_metas (list[dict]): Meta information of input images, one
                dict per image. Preferably with the key ``valid_ratio``.

        Returns:
            Tensor: The raw logit tensor. Shape :math:`(N, T, C - 1)`.
        """
        src_mask = self._get_mask(out_enc, img_metas)
        targets = targets_dict['padded_targets'].to(out_enc.device)
        attn_output = self._attention(targets, out_enc, src_mask=src_mask)
        return self.classifier(attn_output)

    def forward_test(self, feat, out_enc, img_metas):
        """Greedily decode ``max_seq_len`` characters one step at a time.

        Args:
            feat (None): Unused.
            out_enc (Tensor): Encoder output of shape :math:`(N, T, D_m)`.
            img_metas (list[dict]): Meta information of input images, one
                dict per image. Preferably with the key ``valid_ratio``.

        Returns:
            Tensor: Softmax probabilities of shape
            :math:`(N, T_{max}, C - 1)` where :math:`T_{max}` is
            ``max_seq_len``.
        """
        src_mask = self._get_mask(out_enc, img_metas)
        N = out_enc.size(0)
        # (N, max_seq_len + 1): <SOS> followed by <PAD> placeholders, each
        # overwritten with the greedy prediction of the preceding step.
        init_target_seq = torch.full((N, self.max_seq_len + 1),
                                     self.padding_idx,
                                     device=out_enc.device,
                                     dtype=torch.long)
        init_target_seq[:, 0] = self.start_idx

        outputs = []
        for step in range(self.max_seq_len):
            # The whole sequence is re-decoded every step; the causal mask
            # keeps position ``step`` from attending to later slots.
            decoder_output = self._attention(
                init_target_seq, out_enc, src_mask=src_mask)
            # (N, C - 1) class probabilities for the current position.
            step_result = F.softmax(
                self.classifier(decoder_output[:, step, :]), dim=-1)
            outputs.append(step_result)
            _, step_max_index = torch.max(step_result, dim=-1)
            init_target_seq[:, step + 1] = step_max_index

        return torch.stack(outputs, dim=1)
Read the Docs v: v0.6.3
Versions
latest
stable
v0.6.3
v0.6.2
v0.6.1
v0.6.0
v0.5.0
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
dev-1.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.