Shortcuts

Note

You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to enjoy the many new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.

Source code for mmocr.models.textrecog.decoders.master_decoder

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmcv.runner import ModuleList

from mmocr.models.builder import DECODERS
from mmocr.models.common.modules import PositionalEncoding
from .base_decoder import BaseDecoder


def clones(module, N):
    """Build an ``nn.ModuleList`` holding ``N`` independent deep copies.

    Args:
        module (nn.Module): Template module to replicate.
        N (int): Number of copies to create.

    Returns:
        nn.ModuleList: ``N`` deep copies of ``module``; each copy owns its own
        parameters, so they are trained independently.
    """
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)


class Embeddings(nn.Module):
    """Token embedding lookup whose output is scaled by ``sqrt(d_model)``.

    Args:
        d_model (int): Embedding dimension.
        vocab (int): Number of rows in the embedding table.
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, *inputs):
        """Embed the token indices passed as the first positional argument.

        Any further positional arguments are accepted and ignored.
        """
        tokens = inputs[0]
        scale = math.sqrt(self.d_model)
        return self.lut(tokens) * scale


@DECODERS.register_module()
class MasterDecoder(BaseDecoder):
    """Decoder module in `MASTER <https://arxiv.org/abs/1910.02562>`_.

    Code is partially modified from https://github.com/wenwenyu/MASTER-pytorch.

    Args:
        start_idx (int): The index of `<SOS>`.
        padding_idx (int): The index of `<PAD>`.
        num_classes (int): Number of text characters :math:`C`.
        n_layers (int): Number of attention layers.
        n_head (int): Number of parallel attention heads.
        d_model (int): Dimension :math:`E` of the input from previous model.
        feat_size (int): The size of the input feature from previous model,
            usually :math:`H * W`.
        d_inner (int): Hidden dimension of feedforward layers.
        attn_drop (float): Dropout rate of the attention layer.
        ffn_drop (float): Dropout rate of the feedforward layer.
        feat_pe_drop (float): Dropout rate of the feature positional encoding
            layer.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(
        self,
        start_idx,
        padding_idx,
        num_classes=93,
        n_layers=3,
        n_head=8,
        d_model=512,
        feat_size=6 * 40,
        d_inner=2048,
        attn_drop=0.,
        ffn_drop=0.,
        feat_pe_drop=0.2,
        max_seq_len=30,
        init_cfg=None,
    ):
        super(MasterDecoder, self).__init__(init_cfg=init_cfg)
        # A LayerNorm runs before each of self-attention, cross-attention and
        # the feed-forward block.
        operation_order = ('norm', 'self_attn', 'norm', 'cross_attn', 'norm',
                           'ffn')
        decoder_layer = BaseTransformerLayer(
            operation_order=operation_order,
            attn_cfgs=dict(
                type='MultiheadAttention',
                embed_dims=d_model,
                num_heads=n_head,
                attn_drop=attn_drop,
                dropout_layer=dict(type='Dropout', drop_prob=attn_drop),
            ),
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=d_model,
                feedforward_channels=d_inner,
                ffn_drop=ffn_drop,
                dropout_layer=dict(type='Dropout', drop_prob=ffn_drop),
            ),
            norm_cfg=dict(type='LN'),
            batch_first=True,
        )
        # Independent deep copies; the template layer itself is not kept.
        self.decoder_layers = ModuleList(
            [copy.deepcopy(decoder_layer) for _ in range(n_layers)])

        self.cls = nn.Linear(d_model, num_classes)

        self.SOS = start_idx
        self.PAD = padding_idx
        self.max_seq_len = max_seq_len
        self.feat_size = feat_size
        self.n_head = n_head

        self.embedding = Embeddings(d_model=d_model, vocab=num_classes)
        # NOTE(review): max_seq_len + 1 positions, presumably to leave room
        # for the leading <SOS> token of training targets -- confirm against
        # the label convertor.
        self.positional_encoding = PositionalEncoding(
            d_hid=d_model, n_position=self.max_seq_len + 1)
        # Encodes the flattened backbone feature (feat_size = H * W positions)
        # when no encoder output is supplied.
        self.feat_positional_encoding = PositionalEncoding(
            d_hid=d_model, n_position=self.feat_size, dropout=feat_pe_drop)
        self.norm = nn.LayerNorm(d_model)

    def make_mask(self, tgt, device):
        """Make mask for self attention.

        Args:
            tgt (Tensor): Shape [N, l_tgt]
            device (torch.device): Device on which the causal part of the
                mask is created.

        Returns:
            Tensor: Float mask of shape [N * self.n_head, l_tgt, l_tgt] that
            holds 0 where attention is allowed and -1e9 where it is blocked.
        """
        # [N, 1, l_tgt, 1]: True for query positions that are not padding.
        trg_pad_mask = (tgt != self.PAD).unsqueeze(1).unsqueeze(3)
        tgt_len = tgt.size(1)
        # [l_tgt, l_tgt]: causal mask, position i may only see j <= i.
        trg_sub_mask = torch.tril(
            torch.ones((tgt_len, tgt_len), dtype=torch.bool, device=device))
        # Broadcasts to [N, 1, l_tgt, l_tgt].
        allowed = trg_pad_mask & trg_sub_mask
        # Inverse for mmcv's BaseTransformerLayer: turn the boolean "allowed"
        # mask into an additive one (0 = allowed, -1e9 = blocked).
        tgt_mask = torch.zeros_like(allowed, dtype=torch.float).masked_fill(
            ~allowed, -1e9)
        # One copy per head: [N, n_head, L, L] -> [N * n_head, L, L].
        tgt_mask = tgt_mask.repeat(1, self.n_head, 1, 1)
        return tgt_mask.view(-1, tgt_len, tgt_len)

    def decode(self, input, feature, src_mask, tgt_mask):
        """Run the decoder stack on a sequence of target tokens.

        Args:
            input (Tensor): Token indices of shape [N, l_tgt].
            feature (Tensor): Encoder memory, used as key and value of the
                cross attention.
            src_mask (Tensor | None): Mask for the cross attention.
            tgt_mask (Tensor): Self-attention mask from :meth:`make_mask`.

        Returns:
            Tensor: Raw logits of shape [N, l_tgt, C].
        """
        x = self.embedding(input)
        x = self.positional_encoding(x)
        # Presumably consumed in operation_order: self_attn, then cross_attn.
        attn_masks = [tgt_mask, src_mask]
        for layer in self.decoder_layers:
            x = layer(
                query=x, key=feature, value=feature, attn_masks=attn_masks)
        x = self.norm(x)
        return self.cls(x)

    def greedy_forward(self, SOS, feature):
        """Greedily decode for ``self.max_seq_len`` steps.

        Each step re-decodes the whole prefix and appends the most likely
        token for the newest position.

        Args:
            SOS (Tensor): Start tokens of shape [N, 1].
            feature (Tensor): Encoder memory.

        Returns:
            Tensor | None: Logits of the final step, shape
            [N, max_seq_len, C]; ``None`` if ``max_seq_len`` is not positive.
        """
        tokens = SOS
        output = None
        for _ in range(self.max_seq_len):
            target_mask = self.make_mask(tokens, device=feature.device)
            output = self.decode(tokens, feature, None, target_mask)
            # Softmax is monotonic, so the argmax of the logits is the argmax
            # of the probabilities; only the newest position is needed.
            next_word = output[:, -1, :].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_word], dim=1)
        return output

    def _prepare_encoding(self, feat, out_enc):
        """Return the memory the decoder attends to.

        A 4D backbone feature ``(N, E, H, W)`` is flattened to
        ``(N, H * W, E)``. When ``out_enc`` is None, the (flattened) feature
        passed through the feature positional encoding is used instead.
        """
        if feat.dim() > 3:
            b, c, h, w = feat.shape
            feat = feat.view(b, c, h * w).permute((0, 2, 1))
        if out_enc is None:
            out_enc = self.feat_positional_encoding(feat)
        return out_enc

    def forward_train(self, feat, out_enc, targets_dict, img_metas=None):
        """
        Args:
            feat (Tensor): The feature map from backbone of shape
                :math:`(N, E, H, W)`.
            out_enc (Tensor): Encoder output.
            targets_dict (dict | Tensor): A dict with the key
                ``padded_targets``, a tensor of shape :math:`(N, T)`, or that
                tensor itself. Each element is the index of a character.
            img_metas: Unused.

        Returns:
            Tensor: Raw logit tensor of shape :math:`(N, T, C)`.
        """
        out_enc = self._prepare_encoding(feat, out_enc)
        # Targets, mask and memory must share a device: target embeddings
        # cross-attend to out_enc.
        device = out_enc.device
        if isinstance(targets_dict, dict):
            padded_targets = targets_dict['padded_targets'].to(device)
        else:
            padded_targets = targets_dict.to(device)
        src_mask = None
        tgt_mask = self.make_mask(padded_targets, device=device)
        return self.decode(padded_targets, out_enc, src_mask, tgt_mask)

    def forward_test(self, feat, out_enc, img_metas):
        """
        Args:
            feat (Tensor): The feature map from backbone of shape
                :math:`(N, E, H, W)`.
            out_enc (Tensor): Encoder output.
            img_metas: Unused.

        Returns:
            Tensor: Raw logit tensor of shape :math:`(N, T, C)`.
        """
        out_enc = self._prepare_encoding(feat, out_enc)
        batch_size = out_enc.shape[0]
        # [N, 1] column of start tokens that seeds greedy decoding.
        SOS = torch.full((batch_size, 1),
                         self.SOS,
                         dtype=torch.long,
                         device=out_enc.device)
        return self.greedy_forward(SOS, out_enc)
Read the Docs v: v0.6.3
Versions
latest
stable
v0.6.3
v0.6.2
v0.6.1
v0.6.0
v0.5.0
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
dev-1.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.