Shortcuts

Source code for mmocr.models.textrecog.fusers.abi_fuser

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.runner import BaseModule

from mmocr.models.builder import FUSERS


[docs]@FUSERS.register_module() class ABIFuser(BaseModule): """Mix and align visual feature and linguistic feature Implementation of language model of `ABINet <https://arxiv.org/abs/1910.04396>`_. Args: d_model (int): Hidden size of input. max_seq_len (int): Maximum text sequence length :math:`T`. num_chars (int): Number of text characters :math:`C`. init_cfg (dict): Specifies the initialization method for model layers. """ def __init__(self, d_model=512, max_seq_len=40, num_chars=90, init_cfg=None, **kwargs): super().__init__(init_cfg=init_cfg) self.max_seq_len = max_seq_len + 1 # additional stop token self.w_att = nn.Linear(2 * d_model, d_model) self.cls = nn.Linear(d_model, num_chars)
[docs] def forward(self, l_feature, v_feature): """ Args: l_feature: (N, T, E) where T is length, N is batch size and d is dim of model. v_feature: (N, T, E) shape the same as l_feature. Returns: A dict with key ``logits`` The logits of shape (N, T, C) where N is batch size, T is length and C is the number of characters. """ f = torch.cat((l_feature, v_feature), dim=2) f_att = torch.sigmoid(self.w_att(f)) output = f_att * v_feature + (1 - f_att) * l_feature logits = self.cls(output) # (N, T, C) return {'logits': logits}
Read the Docs v: v0.4.1
Versions
latest
stable
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.