Shortcuts

Source code for mmocr.models.kie.losses.sdmgr_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.models.losses import accuracy
from torch import nn

from mmocr.models.builder import LOSSES


@LOSSES.register_module()
class SDMGRLoss(nn.Module):
    """Loss for key information extraction proposed in the paper "Spatial
    Dual-Modality Graph Reasoning for Key Information Extraction".

    https://arxiv.org/abs/2103.14470.

    Args:
        node_weight (float): Multiplier applied to the node classification
            loss. Defaults to 1.0.
        edge_weight (float): Multiplier applied to the edge classification
            loss. Defaults to 1.0.
        ignore (int): Node label value excluded from both the node loss and
            the node accuracy. Defaults to -100.
        edge_ignore (int): Edge label value excluded from both the edge loss
            and the edge accuracy. Defaults to -1, the value that was
            previously hard-coded.
    """

    def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=-100,
                 edge_ignore=-1):
        super().__init__()
        self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
        self.loss_edge = nn.CrossEntropyLoss(ignore_index=edge_ignore)
        self.node_weight = node_weight
        self.edge_weight = edge_weight
        self.ignore = ignore
        self.edge_ignore = edge_ignore

    def forward(self, node_preds, edge_preds, gts):
        """Compute weighted node/edge cross-entropy losses and accuracies.

        Args:
            node_preds (Tensor): Node classification logits. Their rows must
                line up with the node labels concatenated across ``gts``
                (presumably shape (total_nodes, num_node_classes) — TODO
                confirm against the head).
            edge_preds (Tensor): Edge classification logits. Their rows must
                line up with the edge labels flattened and concatenated
                across ``gts``.
            gts (list[Tensor]): One 2-D label tensor per sample. Column 0
                holds node labels. The remaining columns hold edge labels,
                which are flattened row by row.

        Returns:
            dict: ``loss_node`` and ``loss_edge`` (weighted cross-entropy
            losses), plus ``acc_node`` and ``acc_edge`` (accuracy computed
            only over labels that are not ignored).
        """
        node_gts = torch.cat([gt[:, 0] for gt in gts]).long()
        # gt[:, 1:] is a non-contiguous slice, so it must be made contiguous
        # before it can be flattened with ``view``.
        edge_gts = torch.cat(
            [gt[:, 1:].contiguous().view(-1) for gt in gts]).long()

        # The losses already skip ignored labels via ``ignore_index``. The
        # accuracies have no such option, so filter their inputs explicitly
        # with boolean masks that use the same ignore values.
        node_valid = node_gts != self.ignore
        edge_valid = edge_gts != self.edge_ignore

        return dict(
            loss_node=self.node_weight * self.loss_node(node_preds, node_gts),
            loss_edge=self.edge_weight * self.loss_edge(edge_preds, edge_gts),
            acc_node=accuracy(node_preds[node_valid], node_gts[node_valid]),
            acc_edge=accuracy(edge_preds[edge_valid], edge_gts[edge_valid]))
Read the Docs v: v0.4.0
Versions
latest
stable
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.