Shortcuts

Source code for mmocr.models.textrecog.layers.satrn_layers

# Copyright (c) OpenMMLab. All rights reserved.
import copy

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from mmocr.models.common import MultiHeadAttention


class SatrnEncoderLayer(BaseModule):
    """Single encoder layer of SATRN, see `SATRN
    <https://arxiv.org/abs/1910.04396>`_.

    The layer is pre-norm: a ``LayerNorm`` is applied before the multi-head
    self-attention sub-layer and before the :class:`LocalityAwareFeedforward`
    sub-layer, and each sub-layer's output is added back onto its input.

    Args:
        d_model (int): Channel dimension of the input/output features.
        d_inner (int): Hidden channel dimension of the feedforward block.
        n_head (int): Number of attention heads, forwarded to
            ``MultiHeadAttention``.
        d_k (int): Forwarded to ``MultiHeadAttention`` as ``d_k``
            (presumably the per-head key dimension).
        d_v (int): Forwarded to ``MultiHeadAttention`` as ``d_v``
            (presumably the per-head value dimension).
        dropout (float): Dropout rate forwarded to the attention and
            feedforward sub-modules.
        qkv_bias (bool): Forwarded to ``MultiHeadAttention`` (presumably
            toggles bias on the q/k/v projections).
        init_cfg (dict or list[dict], optional): Initialization config passed
            to ``BaseModule``.
    """

    def __init__(self,
                 d_model=512,
                 d_inner=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1,
                 qkv_bias=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Sub-modules are registered in the same order as before so that
        # state_dict key order is unchanged.
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = LocalityAwareFeedforward(
            d_model, d_inner, dropout=dropout)

    def forward(self, x, h, w, mask=None):
        """Apply the layer to a flattened 2D feature map.

        Args:
            x (Tensor): Features of shape (N, H*W, C), where C is ``d_model``.
            h (int): Height H of the feature map ``x`` was flattened from.
            w (int): Width W of that feature map; ``h * w`` must equal the
                second dimension of ``x``.
            mask (Tensor, optional): Passed unchanged to the attention
                module; its expected shape/semantics are defined there.

        Returns:
            Tensor: Output with the same shape as ``x``.
        """
        batch, seq_len, channels = x.size()

        # Pre-norm self-attention followed by a residual connection.
        normed = self.norm1(x)
        x = x + self.attn(normed, normed, normed, mask)

        # The feedforward block is convolutional: fold the token sequence
        # back into an (N, C, H, W) map, run it, then flatten it again.
        feat_map = self.norm2(x).transpose(1, 2).contiguous()
        feat_map = feat_map.view(batch, channels, h, w)
        ff_out = self.feed_forward(feat_map)
        ff_out = ff_out.view(batch, channels, seq_len).transpose(1, 2)

        # Second residual connection.
        return x + ff_out


class LocalityAwareFeedforward(BaseModule):
    """Locality-aware feedforward layer in SATRN, see `SATRN
    <https://arxiv.org/abs/1910.04396>`_.

    The block is three stacked ``ConvModule`` s, each configured with
    BatchNorm and ReLU:

    1. a 1x1 conv mapping ``d_in`` to ``d_hid`` channels,
    2. a 3x3 depthwise conv (``groups=d_hid``) over ``d_hid`` channels,
    3. a 1x1 conv mapping ``d_hid`` back to ``d_in`` channels.

    All convolutions preserve the spatial size (1x1 with no padding, 3x3
    with padding 1).

    Args:
        d_in (int): Number of input (and output) channels.
        d_hid (int): Number of hidden channels.
        dropout (float): Accepted for interface compatibility with the other
            SATRN layers; it is currently not used by this module.
        init_cfg (dict or list[dict], optional): Initialization config passed
            to ``BaseModule``. Defaults to Xavier init for ``Conv2d`` and
            constant init (``val=1``, ``bias=0``) for ``BatchNorm2d``.
    """

    def __init__(self,
                 d_in,
                 d_hid,
                 dropout=0.1,
                 init_cfg=[
                     dict(type='Xavier', layer='Conv2d'),
                     dict(type='Constant', layer='BatchNorm2d', val=1, bias=0)
                 ]):
        # The default ``init_cfg`` list is built once, at function definition
        # time, and shared by every call. Deep-copy it so an instance storing
        # (and possibly mutating) its config never aliases the shared default
        # or a caller-owned config.
        super().__init__(init_cfg=copy.deepcopy(init_cfg))

        # Pointwise projection d_in -> d_hid.
        self.conv1 = ConvModule(
            d_in,
            d_hid,
            kernel_size=1,
            padding=0,
            bias=False,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))

        # Depthwise 3x3 conv (one filter per channel); padding=1 keeps H, W.
        self.depthwise_conv = ConvModule(
            d_hid,
            d_hid,
            kernel_size=3,
            padding=1,
            bias=False,
            groups=d_hid,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))

        # Pointwise projection d_hid -> d_in, restoring the input width.
        self.conv2 = ConvModule(
            d_hid,
            d_in,
            kernel_size=1,
            padding=0,
            bias=False,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))

    def forward(self, x):
        """Run the three convolution stages.

        Args:
            x (Tensor): Feature map of shape (N, d_in, H, W).

        Returns:
            Tensor: Feature map of shape (N, d_in, H, W).
        """
        x = self.conv1(x)
        x = self.depthwise_conv(x)
        x = self.conv2(x)

        return x


class Adaptive2DPositionalEncoding(BaseModule):
    """Implement Adaptive 2D positional encoder for SATRN, see `SATRN
    <https://arxiv.org/abs/1910.04396>`_.

    Modified from https://github.com/Media-Smart/vedastr
    Licensed under the Apache License, Version 2.0 (the "License");

    Fixed sinusoidal tables for the vertical and horizontal positions are
    each multiplied by a per-sample, per-channel factor in (0, 1), predicted
    from the globally average-pooled input. Both scaled tables are then
    added to the input feature map, followed by dropout.

    Args:
        d_hid (int): Dimensions of hidden layer, i.e. the channel count of
            the input feature map.
        n_height (int): Max height of the 2D feature output.
        n_width (int): Max width of the 2D feature output.
        dropout (float): Dropout probability applied to the encoded output.
        init_cfg (dict or list[dict], optional): Initialization config passed
            to ``BaseModule``. Defaults to Xavier init for ``Conv2d``.
    """

    def __init__(self,
                 d_hid=512,
                 n_height=100,
                 n_width=100,
                 dropout=0.1,
                 init_cfg=[dict(type='Xavier', layer='Conv2d')]):
        # The default ``init_cfg`` list is shared across calls; deep-copy it
        # so no instance aliases the shared default or a caller's config.
        super().__init__(init_cfg=copy.deepcopy(init_cfg))

        # (n_height, d_hid) -> (1, d_hid, n_height, 1): varies along H only.
        h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid)
        h_position_encoder = h_position_encoder.transpose(0, 1)
        h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1)

        # (n_width, d_hid) -> (1, d_hid, 1, n_width): varies along W only.
        w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid)
        w_position_encoder = w_position_encoder.transpose(0, 1)
        w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width)

        self.register_buffer('h_position_encoder', h_position_encoder)
        self.register_buffer('w_position_encoder', w_position_encoder)

        self.h_scale = self.scale_factor_generate(d_hid)
        self.w_scale = self.scale_factor_generate(d_hid)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Sinusoid position encoding table.

        Returns:
            Tensor: Table of shape (n_position, d_hid). Even columns hold
            ``sin`` and odd columns hold ``cos`` of
            ``pos / 10000 ** (2 * (j // 2) / d_hid)``.
        """
        denominator = torch.Tensor([
            1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ])
        denominator = denominator.view(1, -1)
        pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
        sinusoid_table = pos_tensor * denominator
        sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
        sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])

        return sinusoid_table

    def scale_factor_generate(self, d_hid):
        """Build the 1x1-conv gate that predicts per-channel scale factors.

        The final ``Sigmoid`` bounds the factors to (0, 1).
        """
        scale_factor = nn.Sequential(
            nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.Sigmoid())

        return scale_factor

    def forward(self, x):
        """Add adaptively scaled 2D positional encodings to ``x``.

        Args:
            x (Tensor): Feature map of shape (N, d_hid, H, W).

        Returns:
            Tensor: Tensor of the same shape as ``x``.

        NOTE: H and W are expected to be at most ``n_height`` / ``n_width``.
        Larger sizes are silently truncated by the slicing below, so the
        following addition fails with a broadcasting error.
        """
        _, _, h, w = x.size()
        avg_pool = self.pool(x)  # (N, d_hid, 1, 1) global context
        h_pos_encoding = \
            self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
        w_pos_encoding = \
            self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]

        out = x + h_pos_encoding + w_pos_encoding

        out = self.dropout(out)

        return out
Read the Docs v: v0.4.0
Versions
latest
stable
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.