# Source: mmocr/models/textrecog/layers/satrn_layers.py
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmocr.models.common import MultiHeadAttention
class SatrnEncoderLayer(BaseModule):
""""""
def __init__(self,
d_model=512,
d_inner=512,
n_head=8,
d_k=64,
d_v=64,
dropout=0.1,
qkv_bias=False,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.norm1 = nn.LayerNorm(d_model)
self.attn = MultiHeadAttention(
n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout)
self.norm2 = nn.LayerNorm(d_model)
self.feed_forward = LocalityAwareFeedforward(
d_model, d_inner, dropout=dropout)
def forward(self, x, h, w, mask=None):
n, hw, c = x.size()
residual = x
x = self.norm1(x)
x = residual + self.attn(x, x, x, mask)
residual = x
x = self.norm2(x)
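        # Restore the 2D layout so the locality-aware feedforward can apply
        # its depthwise convolution over the (H, W) dimensions.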
x = x.transpose(1, 2).contiguous().view(n, c, h, w)
x = self.feed_forward(x)
x = x.view(n, c, hw).transpose(1, 2)
x = residual + x
return x
class LocalityAwareFeedforward(BaseModule):
"""Locality-aware feedforward layer in SATRN, see `SATRN.
<https://arxiv.org/abs/1910.04396>`_
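
    Example (a minimal shape-check sketch; the feature size is an
    illustrative assumption):

        >>> import torch
        >>> ff = LocalityAwareFeedforward(d_in=512, d_hid=1024)
        >>> feat = torch.rand(1, 512, 8, 25)  # (N, C, H, W) feature map
        >>> ff(feat).shape
        torch.Size([1, 512, 8, 25])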
"""
def __init__(self,
d_in,
d_hid,
dropout=0.1,
init_cfg=[
dict(type='Xavier', layer='Conv2d'),
dict(type='Constant', layer='BatchNorm2d', val=1, bias=0)
]):
super().__init__(init_cfg=init_cfg)
self.conv1 = ConvModule(
d_in,
d_hid,
kernel_size=1,
padding=0,
bias=False,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'))
self.depthwise_conv = ConvModule(
d_hid,
d_hid,
kernel_size=3,
padding=1,
bias=False,
groups=d_hid,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'))
self.conv2 = ConvModule(
d_hid,
d_in,
kernel_size=1,
padding=0,
bias=False,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'))
def forward(self, x):
x = self.conv1(x)
x = self.depthwise_conv(x)
x = self.conv2(x)
return x
class Adaptive2DPositionalEncoding(BaseModule):
"""Implement Adaptive 2D positional encoder for SATRN, see
`SATRN <https://arxiv.org/abs/1910.04396>`_
Modified from https://github.com/Media-Smart/vedastr
Licensed under the Apache License, Version 2.0 (the "License");
Args:
d_hid (int): Dimensions of hidden layer.
n_height (int): Max height of the 2D feature output.
n_width (int): Max width of the 2D feature output.
dropout (int): Size of hidden layers of the model.
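
    Example (a minimal usage sketch; the feature size is an illustrative
    assumption):

        >>> import torch
        >>> pos_enc = Adaptive2DPositionalEncoding(d_hid=512)
        >>> feat = torch.rand(1, 512, 8, 25)  # (N, C, H, W), H, W <= 100
        >>> pos_enc(feat).shape
        torch.Size([1, 512, 8, 25])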
"""
def __init__(self,
d_hid=512,
n_height=100,
n_width=100,
dropout=0.1,
init_cfg=[dict(type='Xavier', layer='Conv2d')]):
super().__init__(init_cfg=init_cfg)
h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid)
h_position_encoder = h_position_encoder.transpose(0, 1)
h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1)
w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid)
w_position_encoder = w_position_encoder.transpose(0, 1)
w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width)
self.register_buffer('h_position_encoder', h_position_encoder)
self.register_buffer('w_position_encoder', w_position_encoder)
self.h_scale = self.scale_factor_generate(d_hid)
self.w_scale = self.scale_factor_generate(d_hid)
self.pool = nn.AdaptiveAvgPool2d(1)
self.dropout = nn.Dropout(p=dropout)
def _get_sinusoid_encoding_table(self, n_position, d_hid):
"""Sinusoid position encoding table."""
denominator = torch.Tensor([
1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
])
denominator = denominator.view(1, -1)
pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
sinusoid_table = pos_tensor * denominator
sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])
return sinusoid_table
def scale_factor_generate(self, d_hid):
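        # Two 1x1 convs ending in a sigmoid: maps the globally pooled feature
        # to per-channel scale factors in (0, 1).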
scale_factor = nn.Sequential(
nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.ReLU(inplace=True),
nn.Conv2d(d_hid, d_hid, kernel_size=1), nn.Sigmoid())
return scale_factor
    def forward(self, x):
b, c, h, w = x.size()
avg_pool = self.pool(x)
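        # Scale the fixed sinusoid tables with content-dependent, per-channel
        # factors derived from the pooled feature, cropping the tables to the
        # actual input height and width.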
h_pos_encoding = \
self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
w_pos_encoding = \
self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]
out = x + h_pos_encoding + w_pos_encoding
out = self.dropout(out)
return out