Source code for mmocr.models.kie.heads.sdmgr_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import BaseModule
from torch import nn
from torch.nn import functional as F

from mmocr.models.builder import HEADS, build_loss


@HEADS.register_module()
class SDMGRHead(BaseModule):

    def __init__(self,
                 num_chars=92,
                 visual_dim=64,
                 fusion_dim=1024,
                 node_input=32,
                 node_embed=256,
                 edge_input=5,
                 edge_embed=256,
                 num_gnn=2,
                 num_classes=26,
                 loss=dict(type='SDMGRLoss'),
                 bidirectional=False,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Normal',
                     override=dict(name='edge_embed'),
                     mean=0,
                     std=0.01)):
        super().__init__(init_cfg=init_cfg)
        self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
        self.node_embed = nn.Embedding(num_chars, node_input, 0)
        hidden = node_embed // 2 if bidirectional else node_embed
        self.rnn = nn.LSTM(
            input_size=node_input,
            hidden_size=hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional)
        self.edge_embed = nn.Linear(edge_input, edge_embed)
        self.gnn_layers = nn.ModuleList(
            [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
        self.node_cls = nn.Linear(node_embed, num_classes)
        self.edge_cls = nn.Linear(edge_embed, 2)
        self.loss = build_loss(loss)

    def forward(self, relations, texts, x=None):
        # Count the nodes (text regions) per sample and the non-padding
        # characters in every text instance.
        node_nums, char_nums = [], []
        for text in texts:
            node_nums.append(text.size(0))
            char_nums.append((text > 0).sum(-1))

        # Pad all texts to a common length and embed the character indices.
        max_num = max([char_num.max() for char_num in char_nums])
        all_nodes = torch.cat([
            torch.cat(
                [text,
                 text.new_zeros(text.size(0), max_num - text.size(1))], -1)
            for text in texts
        ])
        embed_nodes = self.node_embed(all_nodes.clamp(min=0).long())
        rnn_nodes, _ = self.rnn(embed_nodes)

        # Use the RNN output at the last valid character as the node feature.
        nodes = rnn_nodes.new_zeros(*rnn_nodes.shape[::2])
        all_nums = torch.cat(char_nums)
        valid = all_nums > 0
        nodes[valid] = rnn_nodes[valid].gather(
            1, (all_nums[valid] - 1).unsqueeze(-1).unsqueeze(-1).expand(
                -1, -1, rnn_nodes.size(-1))).squeeze(1)

        # Optionally fuse visual features into the textual node features.
        if x is not None:
            nodes = self.fusion([x, nodes])

        # Embed and L2-normalize the relation (edge) features.
        all_edges = torch.cat(
            [rel.view(-1, rel.size(-1)) for rel in relations])
        embed_edges = self.edge_embed(all_edges.float())
        embed_edges = F.normalize(embed_edges)

        # Refine node and edge representations with the stacked GNN layers.
        for gnn_layer in self.gnn_layers:
            nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)

        # Per-node class logits and per-edge (linking) logits.
        node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
        return node_cls, edge_cls


class GNNLayer(nn.Module):

    def __init__(self, node_dim=256, edge_dim=256):
        super().__init__()
        self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
        self.coef_fc = nn.Linear(node_dim, 1)
        self.out_fc = nn.Linear(node_dim, node_dim)
        self.relu = nn.ReLU()

    def forward(self, nodes, edges, nums):
        # Build features for every ordered node pair within each sample and
        # concatenate them with the corresponding edge features.
        start, cat_nodes = 0, []
        for num in nums:
            sample_nodes = nodes[start:start + num]
            cat_nodes.append(
                torch.cat([
                    sample_nodes.unsqueeze(1).expand(-1, num, -1),
                    sample_nodes.unsqueeze(0).expand(num, -1, -1)
                ], -1).view(num**2, -1))
            start += num
        cat_nodes = torch.cat([torch.cat(cat_nodes), edges], -1)
        cat_nodes = self.relu(self.in_fc(cat_nodes))
        coefs = self.coef_fc(cat_nodes)

        # Attention over neighbours: self-loops are suppressed by adding a
        # large negative value on the diagonal before the softmax.
        start, residuals = 0, []
        for num in nums:
            residual = F.softmax(
                -torch.eye(num).to(coefs.device).unsqueeze(-1) * 1e9 +
                coefs[start:start + num**2].view(num, num, -1), 1)
            residuals.append(
                (residual *
                 cat_nodes[start:start + num**2].view(num, num, -1)).sum(1))
            start += num**2

        # Residual update of the node features.
        nodes += self.relu(self.out_fc(torch.cat(residuals)))
        return nodes, cat_nodes


class Block(nn.Module):
    """Chunk-wise low-rank bilinear fusion of two feature vectors."""

    def __init__(self,
                 input_dims,
                 output_dim,
                 mm_dim=1600,
                 chunks=20,
                 rank=15,
                 shared=False,
                 dropout_input=0.,
                 dropout_pre_lin=0.,
                 dropout_output=0.,
                 pos_norm='before_cat'):
        super().__init__()
        self.rank = rank
        self.dropout_input = dropout_input
        self.dropout_pre_lin = dropout_pre_lin
        self.dropout_output = dropout_output
        assert (pos_norm in ['before_cat', 'after_cat'])
        self.pos_norm = pos_norm
        # Modules
        self.linear0 = nn.Linear(input_dims[0], mm_dim)
        self.linear1 = (
            self.linear0 if shared else nn.Linear(input_dims[1], mm_dim))
        self.merge_linears0 = nn.ModuleList()
        self.merge_linears1 = nn.ModuleList()
        self.chunks = self.chunk_sizes(mm_dim, chunks)
        for size in self.chunks:
            ml0 = nn.Linear(size, size * rank)
            self.merge_linears0.append(ml0)
            ml1 = ml0 if shared else nn.Linear(size, size * rank)
            self.merge_linears1.append(ml1)
        self.linear_out = nn.Linear(mm_dim, output_dim)

    def forward(self, x):
        x0 = self.linear0(x[0])
        x1 = self.linear1(x[1])
        bs = x1.size(0)
        if self.dropout_input > 0:
            x0 = F.dropout(x0, p=self.dropout_input, training=self.training)
            x1 = F.dropout(x1, p=self.dropout_input, training=self.training)
        x0_chunks = torch.split(x0, self.chunks, -1)
        x1_chunks = torch.split(x1, self.chunks, -1)
        zs = []
        for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks,
                                      self.merge_linears0,
                                      self.merge_linears1):
            m = m0(x0_c) * m1(x1_c)  # bs x split_size*rank
            m = m.view(bs, self.rank, -1)
            z = torch.sum(m, 1)
            if self.pos_norm == 'before_cat':
                # Signed square root followed by L2 normalization.
                z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z))
                z = F.normalize(z)
            zs.append(z)
        z = torch.cat(zs, 1)
        if self.pos_norm == 'after_cat':
            z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z))
            z = F.normalize(z)

        if self.dropout_pre_lin > 0:
            z = F.dropout(z, p=self.dropout_pre_lin, training=self.training)
        z = self.linear_out(z)
        if self.dropout_output > 0:
            z = F.dropout(z, p=self.dropout_output, training=self.training)
        return z

    @staticmethod
    def chunk_sizes(dim, chunks):
        # Split `dim` into `chunks` nearly equal sizes that sum to `dim`.
        split_size = (dim + chunks - 1) // chunks
        sizes_list = [split_size] * chunks
        sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim)
        return sizes_list
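
A minimal usage sketch (not part of the original source) follows: it builds an SDMGRHead with its default arguments and runs a forward pass on dummy tensors. The concrete shapes used here, two text regions with 0-padded character indices and a (num_nodes, num_nodes, 5) relation tensor matching edge_input=5, are assumptions chosen only to satisfy forward(relations, texts, x=None); in practice these inputs come from the KIE data pipeline, and the sketch assumes mmocr is installed so that SDMGRLoss is available in the loss registry.

import torch

import mmocr.models  # noqa: F401  (assumed to register SDMGRLoss in the LOSSES registry)
from mmocr.models.kie.heads.sdmgr_head import SDMGRHead

head = SDMGRHead()  # default config: num_classes=26, edge_input=5
head.eval()

# One sample with two text regions; entries are character indices, 0 is padding.
texts = [torch.tensor([[3, 5, 7], [2, 4, 0]])]
# Dummy relation (edge) features of shape (num_nodes, num_nodes, edge_input).
relations = [torch.rand(2, 2, 5)]

with torch.no_grad():
    node_cls, edge_cls = head(relations, texts)

print(node_cls.shape)  # torch.Size([2, 26]) -- one row of class logits per node
print(edge_cls.shape)  # torch.Size([4, 2])  -- one row of link logits per node pair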