Shortcuts

Source code for mmocr.models.kie.extractors.sdmgr

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
from mmdet.core import bbox2roi
from torch import nn
from torch.nn import functional as F

from mmocr.core import imshow_edge, imshow_node
from mmocr.models.builder import DETECTORS, build_roi_extractor
from mmocr.models.common.detectors import SingleStageDetector
from mmocr.utils import list_from_file


[docs]@DETECTORS.register_module() class SDMGR(SingleStageDetector): """The implementation of the paper: Spatial Dual-Modality Graph Reasoning for Key Information Extraction. https://arxiv.org/abs/2103.14470. Args: visual_modality (bool): Whether use the visual modality. class_list (None | str): Mapping file of class index to class name. If None, class index will be shown in `show_results`, else class name. """ def __init__(self, backbone, neck=None, bbox_head=None, extractor=dict( type='mmdet.SingleRoIExtractor', roi_layer=dict(type='RoIAlign', output_size=7), featmap_strides=[1]), visual_modality=False, train_cfg=None, test_cfg=None, class_list=None, init_cfg=None, openset=False): super().__init__( backbone, neck, bbox_head, train_cfg, test_cfg, init_cfg=init_cfg) self.visual_modality = visual_modality if visual_modality: self.extractor = build_roi_extractor({ **extractor, 'out_channels': self.backbone.base_channels }) self.maxpool = nn.MaxPool2d(extractor['roi_layer']['output_size']) else: self.extractor = None self.class_list = class_list self.openset = openset
[docs] def forward_train(self, img, img_metas, relations, texts, gt_bboxes, gt_labels): """ Args: img (tensor): Input images of shape (N, C, H, W). Typically these should be mean centered and std scaled. img_metas (list[dict]): A list of image info dict where each dict contains: 'img_shape', 'scale_factor', 'flip', and may also contain 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. For details of the values of these keys, please see :class:`mmdet.datasets.pipelines.Collect`. relations (list[tensor]): Relations between bboxes. texts (list[tensor]): Texts in bboxes. gt_bboxes (list[tensor]): Each item is the truth boxes for each image in [tl_x, tl_y, br_x, br_y] format. gt_labels (list[tensor]): Class indices corresponding to each box. Returns: dict[str, tensor]: A dictionary of loss components. """ x = self.extract_feat(img, gt_bboxes) node_preds, edge_preds = self.bbox_head.forward(relations, texts, x) return self.bbox_head.loss(node_preds, edge_preds, gt_labels)
[docs] def forward_test(self, img, img_metas, relations, texts, gt_bboxes, rescale=False): x = self.extract_feat(img, gt_bboxes) node_preds, edge_preds = self.bbox_head.forward(relations, texts, x) return [ dict( img_metas=img_metas, nodes=F.softmax(node_preds, -1), edges=F.softmax(edge_preds, -1)) ]
[docs] def extract_feat(self, img, gt_bboxes): if self.visual_modality: x = super().extract_feat(img)[-1] feats = self.maxpool(self.extractor([x], bbox2roi(gt_bboxes))) return feats.view(feats.size(0), -1) return None
[docs] def show_result(self, img, result, boxes, win_name='', show=False, wait_time=0, out_file=None, **kwargs): """Draw `result` on `img`. Args: img (str or tensor): The image to be displayed. result (dict): The results to draw on `img`. boxes (list): Bbox of img. win_name (str): The window name. wait_time (int): Value of waitKey param. Default: 0. show (bool): Whether to show the image. Default: False. out_file (str or None): The output filename. Default: None. Returns: img (tensor): Only if not `show` or `out_file`. """ img = mmcv.imread(img) img = img.copy() idx_to_cls = {} if self.class_list is not None: for line in list_from_file(self.class_list): class_idx, class_label = line.strip().split() idx_to_cls[class_idx] = class_label # if out_file specified, do not show image in window if out_file is not None: show = False if self.openset: img = imshow_edge( img, result, boxes, show=show, win_name=win_name, wait_time=wait_time, out_file=out_file) else: img = imshow_node( img, result, boxes, idx_to_cls=idx_to_cls, show=show, win_name=win_name, wait_time=wait_time, out_file=out_file) if not (show or out_file): warnings.warn('show==False and out_file is not specified, only ' 'result image will be returned') return img return img
Read the Docs v: v0.4.0
Versions
latest
stable
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.