Shortcuts

Note

You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend that you upgrade to MMOCR 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.

Source code for mmocr.models.kie.extractors.sdmgr

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
from mmdet.core import bbox2roi
from torch import nn
from torch.nn import functional as F

from mmocr.core import imshow_edge, imshow_node
from mmocr.models.builder import DETECTORS, build_roi_extractor
from mmocr.models.common.detectors import SingleStageDetector
from mmocr.utils import list_from_file


@DETECTORS.register_module()
class SDMGR(SingleStageDetector):
    """The implementation of the paper: Spatial Dual-Modality Graph Reasoning
    for Key Information Extraction. https://arxiv.org/abs/2103.14470.

    Args:
        backbone: Forwarded unchanged to ``SingleStageDetector.__init__``.
        neck: Forwarded unchanged to ``SingleStageDetector.__init__``.
        bbox_head: Forwarded unchanged to ``SingleStageDetector.__init__``.
        extractor (dict): Config of the RoI extractor. Only used when
            ``visual_modality`` is true; ``out_channels`` is filled in from
            ``self.backbone.base_channels`` before the extractor is built.
        visual_modality (bool): Whether use the visual modality.
        train_cfg: Forwarded unchanged to ``SingleStageDetector.__init__``.
        test_cfg: Forwarded unchanged to ``SingleStageDetector.__init__``.
        class_list (None | str): Mapping file of class index to class name.
            If None, class index will be shown in `show_results`, else
            class name.
        init_cfg: Forwarded unchanged to ``SingleStageDetector.__init__``.
        openset (bool): If true, ``show_result`` draws with ``imshow_edge``
            instead of ``imshow_node``.
    """

    # NOTE: the ``extractor`` default is a shared dict; it is only read here
    # (a copy is made with ``{**extractor, ...}``), never mutated.
    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 extractor=dict(
                     type='mmdet.SingleRoIExtractor',
                     roi_layer=dict(type='RoIAlign', output_size=7),
                     featmap_strides=[1]),
                 visual_modality=False,
                 train_cfg=None,
                 test_cfg=None,
                 class_list=None,
                 init_cfg=None,
                 openset=False):
        super().__init__(
            backbone, neck, bbox_head, train_cfg, test_cfg, init_cfg=init_cfg)
        self.visual_modality = visual_modality
        if visual_modality:
            self.extractor = build_roi_extractor({
                **extractor, 'out_channels':
                self.backbone.base_channels
            })
            # Pooling window equals the RoI output size configured above.
            self.maxpool = nn.MaxPool2d(extractor['roi_layer']['output_size'])
        else:
            self.extractor = None
        self.class_list = class_list
        self.openset = openset

    def forward_train(self, img, img_metas, relations, texts, gt_bboxes,
                      gt_labels):
        """
        Args:
            img (tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dict where each dict
                contains: 'img_shape', 'scale_factor', 'flip', and may also
                contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'. For details of the values of these keys,
                please see :class:`mmdet.datasets.pipelines.Collect`.
            relations (list[tensor]): Relations between bboxes.
            texts (list[tensor]): Texts in bboxes.
            gt_bboxes (list[tensor]): Each item is the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[tensor]): Class indices corresponding to each box.

        Returns:
            dict[str, tensor]: A dictionary of loss components.
        """
        x = self.extract_feat(img, gt_bboxes)
        node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
        return self.bbox_head.loss(node_preds, edge_preds, gt_labels)

    def forward_test(self,
                     img,
                     img_metas,
                     relations,
                     texts,
                     gt_bboxes,
                     rescale=False):
        """Run the head on one batch and return softmax-normalised outputs.

        Args:
            img (tensor): Input images, passed to :meth:`extract_feat`.
            img_metas (list[dict]): Image info dicts; returned unchanged in
                the result.
            relations (list[tensor]): Relations between bboxes.
            texts (list[tensor]): Texts in bboxes.
            gt_bboxes (list[tensor]): Boxes passed to :meth:`extract_feat`.
            rescale (bool): Accepted but not used by this method.

        Returns:
            list[dict]: A single-element list whose dict has the keys
            ``img_metas``, ``nodes`` and ``edges``; ``nodes`` and ``edges``
            are the head predictions after a softmax over the last dimension.
        """
        x = self.extract_feat(img, gt_bboxes)
        node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
        return [
            dict(
                img_metas=img_metas,
                nodes=F.softmax(node_preds, -1),
                edges=F.softmax(edge_preds, -1))
        ]

    def extract_feat(self, img, gt_bboxes):
        """Extract per-box visual features, or None without visual modality.

        Args:
            img (tensor): Input images.
            gt_bboxes (list[tensor]): Boxes converted to RoIs with
                ``bbox2roi``.

        Returns:
            tensor | None: ``None`` when ``self.visual_modality`` is false.
            Otherwise the max-pooled RoI features flattened to two
            dimensions (first dimension preserved).
        """
        if not self.visual_modality:
            return None
        # Only the last feature map of the parent's output is used.
        x = super().extract_feat(img)[-1]
        feats = self.maxpool(self.extractor([x], bbox2roi(gt_bboxes)))
        return feats.view(feats.size(0), -1)

    def show_result(self,
                    img,
                    result,
                    boxes,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    **kwargs):
        """Draw `result` on `img`.

        Args:
            img (str or tensor): The image to be displayed.
            result (dict): The results to draw on `img`.
            boxes (list): Bbox of img.
            win_name (str): The window name.
            wait_time (int): Value of waitKey param.
                Default: 0.
            show (bool): Whether to show the image.
                Default: False.
            out_file (str or None): The output filename.
                Default: None.

        Returns:
            img (tensor): The drawn image. It is always returned; a warning
            is emitted when neither `show` nor `out_file` is set, since the
            return value is then the only output.
        """
        img = mmcv.imread(img)
        img = img.copy()

        idx_to_cls = {}
        if self.class_list is not None:
            for line in list_from_file(self.class_list):
                entry = line.strip()
                # Skip blank lines instead of failing on the unpacking below.
                if not entry:
                    continue
                # Split on the first whitespace run only, so that class
                # names containing spaces are kept intact. Keys stay as the
                # strings read from the file.
                class_idx, class_label = entry.split(maxsplit=1)
                idx_to_cls[class_idx] = class_label

        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False

        if self.openset:
            img = imshow_edge(
                img,
                result,
                boxes,
                show=show,
                win_name=win_name,
                wait_time=wait_time,
                out_file=out_file)
        else:
            img = imshow_node(
                img,
                result,
                boxes,
                idx_to_cls=idx_to_cls,
                show=show,
                win_name=win_name,
                wait_time=wait_time,
                out_file=out_file)

        if not (show or out_file):
            warnings.warn('show==False and out_file is not specified, only '
                          'result image will be returned')

        return img
Read the Docs v: v0.6.3
Versions
latest
stable
v0.6.3
v0.6.2
v0.6.1
v0.6.0
v0.5.0
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
dev-1.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.