Shortcuts

Source code for mmocr.models.textdet.dense_heads.drrg_head

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import BaseModule

from mmocr.models.builder import HEADS, build_loss
from mmocr.models.textdet.modules import GCN, LocalGraphs, ProposalLocalGraphs
from mmocr.utils import check_argument
from .head_mixin import HeadMixin


[docs]@HEADS.register_module() class DRRGHead(HeadMixin, BaseModule): """The class for DRRG head: `Deep Relational Reasoning Graph Network for Arbitrary Shape Text Detection <https://arxiv.org/abs/2003.07493>`_. Args: k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2. num_adjacent_linkages (int): The number of linkages when constructing adjacent matrix. node_geo_feat_len (int): The length of embedded geometric feature vector of a component. pooling_scale (float): The spatial scale of rotated RoI-Align. pooling_output_size (tuple(int)): The output size of RRoI-Aligning. nms_thr (float): The locality-aware NMS threshold of text components. min_width (float): The minimum width of text components. max_width (float): The maximum width of text components. comp_shrink_ratio (float): The shrink ratio of text components. comp_ratio (float): The reciprocal of aspect ratio of text components. comp_score_thr (float): The score threshold of text components. text_region_thr (float): The threshold for text region probability map. center_region_thr (float): The threshold for text center region probability map. center_region_area_thr (int): The threshold for filtering small-sized text center region. local_graph_thr (float): The threshold to filter identical local graphs. loss (dict): The config of loss that DRRGHead uses.. postprocessor (dict): Config of postprocessor for Drrg. init_cfg (dict or list[dict], optional): Initialization configs. """ def __init__(self, in_channels, k_at_hops=(8, 4), num_adjacent_linkages=3, node_geo_feat_len=120, pooling_scale=1.0, pooling_output_size=(4, 3), nms_thr=0.3, min_width=8.0, max_width=24.0, comp_shrink_ratio=1.03, comp_ratio=0.4, comp_score_thr=0.3, text_region_thr=0.2, center_region_thr=0.2, center_region_area_thr=50, local_graph_thr=0.7, loss=dict(type='DRRGLoss'), postprocessor=dict(type='DRRGPostprocessor', link_thr=0.85), train_cfg=None, test_cfg=None, init_cfg=dict( type='Normal', override=dict(name='out_conv'), mean=0, std=0.01), **kwargs): old_keys = ['text_repr_type', 'decoding_type', 'link_thr'] for key in old_keys: if kwargs.get(key, None): postprocessor[key] = kwargs.get(key) warnings.warn( f'{key} is deprecated, please specify ' 'it in postprocessor config dict. See ' 'https://github.com/open-mmlab/mmocr/pull/640' ' for details.', UserWarning) BaseModule.__init__(self, init_cfg=init_cfg) HeadMixin.__init__(self, loss, postprocessor) assert isinstance(in_channels, int) assert isinstance(k_at_hops, tuple) assert isinstance(num_adjacent_linkages, int) assert isinstance(node_geo_feat_len, int) assert isinstance(pooling_scale, float) assert isinstance(pooling_output_size, tuple) assert isinstance(comp_shrink_ratio, float) assert isinstance(nms_thr, float) assert isinstance(min_width, float) assert isinstance(max_width, float) assert isinstance(comp_ratio, float) assert isinstance(comp_score_thr, float) assert isinstance(text_region_thr, float) assert isinstance(center_region_thr, float) assert isinstance(center_region_area_thr, int) assert isinstance(local_graph_thr, float) self.in_channels = in_channels self.out_channels = 6 self.downsample_ratio = 1.0 self.k_at_hops = k_at_hops self.num_adjacent_linkages = num_adjacent_linkages self.node_geo_feat_len = node_geo_feat_len self.pooling_scale = pooling_scale self.pooling_output_size = pooling_output_size self.comp_shrink_ratio = comp_shrink_ratio self.nms_thr = nms_thr self.min_width = min_width self.max_width = max_width self.comp_ratio = comp_ratio self.comp_score_thr = comp_score_thr self.text_region_thr = text_region_thr self.center_region_thr = center_region_thr self.center_region_area_thr = center_region_area_thr self.local_graph_thr = local_graph_thr self.loss_module = build_loss(loss) self.train_cfg = train_cfg self.test_cfg = test_cfg self.out_conv = nn.Conv2d( in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=1, stride=1, padding=0) self.graph_train = LocalGraphs(self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len, self.pooling_scale, self.pooling_output_size, self.local_graph_thr) self.graph_test = ProposalLocalGraphs( self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len, self.pooling_scale, self.pooling_output_size, self.nms_thr, self.min_width, self.max_width, self.comp_shrink_ratio, self.comp_ratio, self.comp_score_thr, self.text_region_thr, self.center_region_thr, self.center_region_area_thr) pool_w, pool_h = self.pooling_output_size node_feat_len = (pool_w * pool_h) * ( self.in_channels + self.out_channels) + self.node_geo_feat_len self.gcn = GCN(node_feat_len)
[docs] def forward(self, inputs, gt_comp_attribs): """ Args: inputs (Tensor): Shape of :math:`(N, C, H, W)`. gt_comp_attribs (list[ndarray]): The padded text component attributes. Shape: (num_component, 8). Returns: tuple: Returns (pred_maps, (gcn_pred, gt_labels)). - | pred_maps (Tensor): Prediction map with shape :math:`(N, C_{out}, H, W)`. - | gcn_pred (Tensor): Prediction from GCN module, with shape :math:`(N, 2)`. - | gt_labels (Tensor): Ground-truth label with shape :math:`(N, 8)`. """ pred_maps = self.out_conv(inputs) feat_maps = torch.cat([inputs, pred_maps], dim=1) node_feats, adjacent_matrices, knn_inds, gt_labels = self.graph_train( feat_maps, np.stack(gt_comp_attribs)) gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inds) return pred_maps, (gcn_pred, gt_labels)
[docs] def single_test(self, feat_maps): r""" Args: feat_maps (Tensor): Shape of :math:`(N, C, H, W)`. Returns: tuple: Returns (edge, score, text_comps). - | edge (ndarray): The edge array of shape :math:`(N, 2)` where each row is a pair of text component indices that makes up an edge in graph. - | score (ndarray): The score array of shape :math:`(N,)`, corresponding to the edge above. - | text_comps (ndarray): The text components of shape :math:`(N, 9)` where each row corresponds to one box and its score: (x1, y1, x2, y2, x3, y3, x4, y4, score). """ pred_maps = self.out_conv(feat_maps) feat_maps = torch.cat([feat_maps, pred_maps], dim=1) none_flag, graph_data = self.graph_test(pred_maps, feat_maps) (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds, pivot_local_graphs, text_comps) = graph_data if none_flag: return None, None, None gcn_pred = self.gcn(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds) pred_labels = F.softmax(gcn_pred, dim=1) edges = [] scores = [] pivot_local_graphs = pivot_local_graphs.long().squeeze().cpu().numpy() for pivot_ind, pivot_local_graph in enumerate(pivot_local_graphs): pivot = pivot_local_graph[0] for k_ind, neighbor_ind in enumerate(pivots_knn_inds[pivot_ind]): neighbor = pivot_local_graph[neighbor_ind.item()] edges.append([pivot, neighbor]) scores.append( pred_labels[pivot_ind * pivots_knn_inds.shape[1] + k_ind, 1].item()) edges = np.asarray(edges) scores = np.asarray(scores) return edges, scores, text_comps
[docs] def get_boundary(self, edges, scores, text_comps, img_metas, rescale): """Compute text boundaries via post processing. Args: edges (ndarray): The edge array of shape N * 2, each row is a pair of text component indices that makes up an edge in graph. scores (ndarray): The edge score array. text_comps (ndarray): The text components. img_metas (list[dict]): The image meta infos. rescale (bool): Rescale boundaries to the original image resolution. Returns: dict: The result dict containing key `boundary_result`. """ assert check_argument.is_type_list(img_metas, dict) assert isinstance(rescale, bool) boundaries = [] if edges is not None: boundaries = self.postprocessor(edges, scores, text_comps) if rescale: boundaries = self.resize_boundary( boundaries, 1.0 / self.downsample_ratio / img_metas[0]['scale_factor']) results = dict(boundary_result=boundaries) return results
Read the Docs v: v0.4.0
Versions
latest
stable
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.