# Source code for mmocr.models.textdet.heads.drrg_head

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from lanms import merge_quadrangle_n9 as la_nms
except ImportError:
    la_nms = None
from mmcv.ops import RoIAlignRotated
from mmengine.model import BaseModule
from numpy import ndarray
from torch import Tensor
from torch.nn import init

from mmocr.models.textdet.heads import BaseTextDetHead
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from mmocr.utils import fill_hole


def normalize_adjacent_matrix(mat: ndarray) -> ndarray:
    """Normalize an adjacency matrix for use in a GCN layer.

    Self-loops are added first (``A' = A + I``). Let ``D`` be the diagonal
    matrix holding the inverse square roots of the column sums of ``A'``,
    with sums clipped at zero. The result is ``(A' D)^T D``, which equals
    ``D A' D`` when ``A`` is symmetric. Entries whose inverse square root
    would be infinite (zero column sums) are set to zero instead. This code
    was partially adapted from https://github.com/GXYM/DRRG licensed under
    the MIT license.

    Args:
        mat (ndarray): A square 2-D adjacency matrix.

    Returns:
        ndarray: The normalized adjacency matrix, same shape as ``mat``.
    """
    assert mat.ndim == 2
    assert mat.shape[0] == mat.shape[1]

    with_self_loops = mat + np.eye(mat.shape[0])
    degrees = np.clip(with_self_loops.sum(axis=0), 0, None)
    inv_sqrt_deg = np.power(degrees, -0.5).flatten()
    # A zero degree gives inf above; such nodes get a zero scale instead.
    inv_sqrt_deg[np.isinf(inv_sqrt_deg)] = 0.0
    scale = np.diag(inv_sqrt_deg)
    return with_self_loops.dot(scale).T.dot(scale)


def euclidean_distance_matrix(mat_a: ndarray, mat_b: ndarray) -> ndarray:
    """Compute pairwise Euclidean distances between two point sets.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 a.b``. Small negative
    values caused by floating-point cancellation are clamped to zero before
    the square root is taken.

    Args:
        mat_a (ndarray): Points of shape (m, k).
        mat_b (ndarray): Points of shape (n, k), with the same k as mat_a.

    Returns:
        ndarray: Matrix of shape (m, n) whose entry (i, j) is the distance
        between ``mat_a[i]`` and ``mat_b[j]``.
    """
    assert mat_a.ndim == 2
    assert mat_b.ndim == 2
    assert mat_a.shape[1] == mat_b.shape[1]

    num_a, num_b = mat_a.shape[0], mat_b.shape[0]
    sq_norm_a = np.multiply(mat_a, mat_a).sum(axis=1)
    sq_norm_b = np.multiply(mat_b, mat_b).sum(axis=1)

    # Expand both norm vectors to (m, n) through a float64 ones matrix. This
    # also promotes integer / lower-precision inputs to float64.
    tiled_a = sq_norm_a[:, np.newaxis] * np.ones(shape=(1, num_b))
    tiled_b = sq_norm_b * np.ones(shape=(num_a, 1))
    dist_sq = tiled_a + tiled_b - 2 * mat_a.dot(mat_b.T)

    dist_sq[dist_sq < 0.0] = 0.0
    return np.sqrt(dist_sq)


def feature_embedding(input_feats: ndarray, out_feat_len: int) -> ndarray:
    """Embed features. This code was partially adapted from
    https://github.com/GXYM/DRRG licensed under the MIT license.

    Each node's feature vector is repeated in chunks of ``d`` values until
    ``out_feat_len`` values are covered. When ``out_feat_len`` is not a
    multiple of ``d``, one extra chunk is added: it holds the first
    ``out_feat_len % d`` input values, is zero-padded to ``d``, and the
    output is truncated to ``out_feat_len``. Every chunk is divided by its
    own power-of-1000 wavelength and passed through ``sin`` or ``cos``.

    Args:
        input_feats (ndarray): The input features of shape (N, d), where N is
            the number of nodes in graph, d is the input feature vector length.
        out_feat_len (int): The length of output feature vector. Must be at
            least d.

    Returns:
        ndarray: The embedded features of shape (N, out_feat_len). The dtype
        is always ``float32``, whether or not ``out_feat_len`` is a multiple
        of d.
    """
    assert input_feats.ndim == 2
    assert isinstance(out_feat_len, int)
    assert out_feat_len >= input_feats.shape[1]

    num_nodes, feat_dim = input_feats.shape
    feat_repeat_times, residue_dim = divmod(out_feat_len, feat_dim)
    has_residue = residue_dim > 0
    num_chunks = feat_repeat_times + 1 if has_residue else feat_repeat_times

    if has_residue:
        # NOTE(review): here ``+ 1`` is added after the division, whereas the
        # even-division case uses no offset at all. Kept unchanged so the
        # embedding values match the original implementation.
        wave_exponents = [
            2.0 * (j // 2) / feat_repeat_times + 1 for j in range(num_chunks)
        ]
    else:
        wave_exponents = [
            2.0 * (j // 2) / feat_repeat_times for j in range(num_chunks)
        ]
    embed_wave = np.array([np.power(1000, e) for e in wave_exponents
                           ]).reshape((num_chunks, 1, 1))

    # Shape (feat_repeat_times, N, d): the input stacked along a new axis 0.
    repeat_feats = np.repeat(
        np.expand_dims(input_feats, axis=0), feat_repeat_times, axis=0)
    if has_residue:
        residue_feats = np.hstack([
            input_feats[:, 0:residue_dim],
            np.zeros((num_nodes, feat_dim - residue_dim))
        ])
        repeat_feats = np.concatenate(
            [repeat_feats, np.expand_dims(residue_feats, axis=0)], axis=0)

    embedded_feats = repeat_feats / embed_wave
    # NOTE(review): axis 1 of this (chunks, N, d) array indexes nodes, so
    # even-indexed nodes get sin and odd-indexed nodes get cos. This matches
    # the original implementation and is kept for compatibility.
    embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
    embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
    embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
        (num_nodes, -1))[:, 0:out_feat_len]

    # Return float32 consistently. The residue case previously leaked float64.
    return embedded_feats.astype(np.float32, copy=False)


@MODELS.register_module()
class DRRGHead(BaseTextDetHead):
    """The class for DRRG head: `Deep Relational Reasoning Graph Network for
    Arbitrary Shape Text Detection <https://arxiv.org/abs/2003.07493>`_.

    Args:
        in_channels (int): The number of input channels.
        k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2.
            Defaults to (8, 4).
        num_adjacent_linkages (int): The number of linkages when constructing
            adjacent matrix. Defaults to 3.
        node_geo_feat_len (int): The length of embedded geometric feature
            vector of a component. Defaults to 120.
        pooling_scale (float): The spatial scale of rotated RoI-Align.
            Defaults to 1.0.
        pooling_output_size (tuple(int)): The output size of RRoI-Aligning.
            Defaults to (4, 3).
        nms_thr (float): The locality-aware NMS threshold of text components.
            Defaults to 0.3.
        min_width (float): The minimum width of text components.
            Defaults to 8.0.
        max_width (float): The maximum width of text components.
            Defaults to 24.0.
        comp_shrink_ratio (float): The shrink ratio of text components.
            Defaults to 1.03.
        comp_ratio (float): The reciprocal of aspect ratio of text components.
            Defaults to 0.4.
        comp_score_thr (float): The score threshold of text components.
            Defaults to 0.3.
        text_region_thr (float): The threshold for text region probability
            map. Defaults to 0.2.
        center_region_thr (float): The threshold for text center region
            probability map. Defaults to 0.2.
        center_region_area_thr (int): The threshold for filtering small-sized
            text center region. Defaults to 50.
        local_graph_thr (float): The threshold to filter identical local
            graphs. Defaults to 0.7.
        module_loss (dict): The config of loss that DRRGHead uses. Defaults to
            ``dict(type='DRRGModuleLoss')``.
        postprocessor (dict): Config of postprocessor for DRRG. Defaults to
            ``dict(type='DRRGPostprocessor', link_thr=0.85)``.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to ``dict(type='Normal', override=dict(name='out_conv'),
            mean=0, std=0.01)``.
    """

    def __init__(
        self,
        in_channels: int,
        k_at_hops: Tuple[int, int] = (8, 4),
        num_adjacent_linkages: int = 3,
        node_geo_feat_len: int = 120,
        pooling_scale: float = 1.0,
        pooling_output_size: Tuple[int, int] = (4, 3),
        nms_thr: float = 0.3,
        min_width: float = 8.0,
        max_width: float = 24.0,
        comp_shrink_ratio: float = 1.03,
        comp_ratio: float = 0.4,
        comp_score_thr: float = 0.3,
        text_region_thr: float = 0.2,
        center_region_thr: float = 0.2,
        center_region_area_thr: int = 50,
        local_graph_thr: float = 0.7,
        module_loss: Dict = dict(type='DRRGModuleLoss'),
        postprocessor: Dict = dict(type='DRRGPostprocessor', link_thr=0.85),
        init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
            type='Normal', override=dict(name='out_conv'), mean=0, std=0.01)
    ) -> None:
        super().__init__(
            module_loss=module_loss,
            postprocessor=postprocessor,
            init_cfg=init_cfg)

        assert isinstance(in_channels, int)
        assert isinstance(k_at_hops, tuple)
        assert isinstance(num_adjacent_linkages, int)
        assert isinstance(node_geo_feat_len, int)
        assert isinstance(pooling_scale, float)
        assert isinstance(pooling_output_size, tuple)
        assert isinstance(comp_shrink_ratio, float)
        assert isinstance(nms_thr, float)
        assert isinstance(min_width, float)
        assert isinstance(max_width, float)
        assert isinstance(comp_ratio, float)
        assert isinstance(comp_score_thr, float)
        assert isinstance(text_region_thr, float)
        assert isinstance(center_region_thr, float)
        assert isinstance(center_region_area_thr, int)
        assert isinstance(local_graph_thr, float)

        self.in_channels = in_channels
        # Number of channels of the prediction map produced by ``out_conv``.
        self.out_channels = 6
        self.downsample_ratio = 1.0
        self.k_at_hops = k_at_hops
        self.num_adjacent_linkages = num_adjacent_linkages
        self.node_geo_feat_len = node_geo_feat_len
        self.pooling_scale = pooling_scale
        self.pooling_output_size = pooling_output_size
        self.comp_shrink_ratio = comp_shrink_ratio
        self.nms_thr = nms_thr
        self.min_width = min_width
        self.max_width = max_width
        self.comp_ratio = comp_ratio
        self.comp_score_thr = comp_score_thr
        self.text_region_thr = text_region_thr
        self.center_region_thr = center_region_thr
        self.center_region_area_thr = center_region_area_thr
        self.local_graph_thr = local_graph_thr

        # 1x1 conv, stride 1, no padding: spatial size is preserved.
        self.out_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0)

        # Graph builder used by ``loss`` (from ground-truth components).
        self.graph_train = LocalGraphs(self.k_at_hops,
                                       self.num_adjacent_linkages,
                                       self.node_geo_feat_len,
                                       self.pooling_scale,
                                       self.pooling_output_size,
                                       self.local_graph_thr)

        # Graph builder used by ``forward`` (from predicted components).
        self.graph_test = ProposalLocalGraphs(
            self.k_at_hops, self.num_adjacent_linkages,
            self.node_geo_feat_len, self.pooling_scale,
            self.pooling_output_size, self.nms_thr, self.min_width,
            self.max_width, self.comp_shrink_ratio, self.comp_ratio,
            self.comp_score_thr, self.text_region_thr,
            self.center_region_thr, self.center_region_area_thr)

        # Node features come from a map with ``in_channels + out_channels``
        # channels (inputs concatenated with pred_maps), pooled to
        # ``pool_w * pool_h`` cells, plus the geometric embedding.
        pool_w, pool_h = self.pooling_output_size
        node_feat_len = (pool_w * pool_h) * (
            self.in_channels + self.out_channels) + self.node_geo_feat_len
        self.gcn = GCN(node_feat_len)

    def loss(self, inputs: torch.Tensor, data_samples: List[TextDetDataSample]
             ) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute the DRRG training loss.

        Args:
            inputs (Tensor): Shape of :math:`(N, C, H, W)`.
            data_samples (List[TextDetDataSample]): List of data samples.

        Returns:
            The output of ``self.module_loss`` called with the triple
            ``(pred_maps, gcn_pred, gt_labels)`` and ``data_samples``, where:

            - pred_maps (Tensor): Prediction map with shape
              :math:`(N, 6, H, W)`.
            - gcn_pred (Tensor): Prediction from GCN module, with
              shape :math:`(N, 2)`.
            - gt_labels (Tensor): Ground-truth label of shape
              :math:`(m, n)` where :math:`m * n = N`.
        """
        targets = self.module_loss.get_targets(data_samples)
        # The component attributes are the last element of the targets.
        gt_comp_attribs = targets[-1]

        pred_maps = self.out_conv(inputs)
        feat_maps = torch.cat([inputs, pred_maps], dim=1)
        node_feats, adjacent_matrices, knn_inds, gt_labels = self.graph_train(
            feat_maps, np.stack(gt_comp_attribs))

        gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inds)

        return self.module_loss((pred_maps, gcn_pred, gt_labels), data_samples)

    def forward(
        self,
        inputs: Tensor,
        data_samples: Optional[List[TextDetDataSample]] = None
    ) -> Tuple[Optional[ndarray], Optional[ndarray], Optional[ndarray]]:
        r"""Run DRRG head in prediction mode, and return the raw outputs only.

        Args:
            inputs (Tensor): Shape of :math:`(1, C, H, W)`.
            data_samples (list[TextDetDataSample], optional): A list of data
                samples. Defaults to None.

        Returns:
            tuple: Returns (edge, score, text_comps). All three are ``None``
            when ``graph_test`` reports that no graph could be built.

            - edge (ndarray): The edge array of shape :math:`(N_{edges}, 2)`
              where each row is a pair of text component indices
              that makes up an edge in graph.
            - score (ndarray): The score array of shape :math:`(N_{edges},)`,
              corresponding to the edge above.
            - text_comps (ndarray): The text components of shape
              :math:`(M, 9)` where each row corresponds to one box and
              its score: (x1, y1, x2, y2, x3, y3, x4, y4, score).
        """
        pred_maps = self.out_conv(inputs)
        feat_maps = torch.cat([inputs, pred_maps], dim=1)

        none_flag, graph_data = self.graph_test(pred_maps, feat_maps)
        # Check the flag before unpacking, so the early exit never depends on
        # what ``graph_data`` holds when graph construction failed.
        if none_flag:
            return None, None, None
        (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
         pivot_local_graphs, text_comps) = graph_data

        gcn_pred = self.gcn(local_graphs_node_feat, adjacent_matrices,
                            pivots_knn_inds)
        pred_labels = F.softmax(gcn_pred, dim=1)

        # NOTE(review): ``squeeze()`` drops every size-1 dimension; confirm
        # the graph axis can never be 1 in what ``graph_test`` returns.
        pivot_local_graphs = pivot_local_graphs.long().squeeze().cpu().numpy()
        # Rows of ``pred_labels`` are laid out pivot-major: row
        # ``pivot_ind * num_knn + k_ind`` scores the k-th neighbor of a pivot.
        num_knn = pivots_knn_inds.shape[1]
        edges = []
        scores = []
        for pivot_ind, pivot_local_graph in enumerate(pivot_local_graphs):
            # The first entry of each local graph is its pivot node.
            pivot = pivot_local_graph[0]
            for k_ind, neighbor_ind in enumerate(pivots_knn_inds[pivot_ind]):
                neighbor = pivot_local_graph[neighbor_ind.item()]
                edges.append([pivot, neighbor])
                scores.append(
                    pred_labels[pivot_ind * num_knn + k_ind, 1].item())

        return np.asarray(edges), np.asarray(scores), text_comps
class LocalGraphs:
    """Generate local graphs for GCN to classify the neighbors of a pivot for
    `DRRG: Deep Relational Reasoning Graph Network for Arbitrary Shape Text
    Detection <https://arxiv.org/abs/2003.07493>`_.

    This code was partially adapted from https://github.com/GXYM/DRRG
    licensed under the MIT license.

    Args:
        k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2.
        num_adjacent_linkages (int): The number of linkages when constructing
            adjacent matrix.
        node_geo_feat_len (int): The length of embedded geometric feature
            vector of a text component.
        pooling_scale (float): The spatial scale of rotated RoI-Align.
        pooling_output_size (tuple(int)): The output size of rotated RoI-Align.
        local_graph_thr (float): The threshold for filtering out identical
            local graphs.
    """

    def __init__(self, k_at_hops: Tuple[int, int], num_adjacent_linkages: int,
                 node_geo_feat_len: int, pooling_scale: float,
                 pooling_output_size: Sequence[int],
                 local_graph_thr: float) -> None:

        assert len(k_at_hops) == 2
        assert all(isinstance(n, int) for n in k_at_hops)
        assert isinstance(num_adjacent_linkages, int)
        assert isinstance(node_geo_feat_len, int)
        assert isinstance(pooling_scale, float)
        assert all(isinstance(n, int) for n in pooling_output_size)
        assert isinstance(local_graph_thr, float)

        self.k_at_hops = k_at_hops
        self.num_adjacent_linkages = num_adjacent_linkages
        self.node_geo_feat_dim = node_geo_feat_len
        self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
        self.local_graph_thr = local_graph_thr

    def generate_local_graphs(self, sorted_dist_inds: ndarray,
                              gt_comp_labels: ndarray
                              ) -> Tuple[List[List[int]], List[List[int]]]:
        """Generate local graphs for GCN to predict which instance a text
        component belongs to.

        Each pivot's local graph contains its ``k_at_hops[0]`` nearest
        neighbours plus, for each of those, their ``k_at_hops[1]`` nearest
        neighbours (the pivot itself excluded). A pivot's graph is dropped
        as redundant when, for some already-kept graph, the IoU of the two
        neighbour sets exceeds ``local_graph_thr``, the pivot is among the
        kept graph's k-nearest neighbours, and both pivots share the same
        non-zero ground-truth label.

        Args:
            sorted_dist_inds (ndarray): The complete graph node indices, which
                is sorted according to the Euclidean distance.
            gt_comp_labels (ndarray): The ground truth labels define the
                instance to which the text components (nodes in graphs) belong.

        Returns:
            Tuple(pivot_local_graphs, pivot_knns):

            - pivot_local_graphs (list[list[int]]): The list of local graph
              neighbor indices of pivots. The first entry of each graph is
              its pivot.
            - pivot_knns (list[list[int]]): The list of k-nearest neighbor
              indices of pivots. The first entry of each list is its pivot.
        """

        assert sorted_dist_inds.ndim == 2
        assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
                gt_comp_labels.shape[0])

        # Column 0 of each sorted row is skipped (presumably the node itself,
        # at distance 0); the pivot is also discarded explicitly below.
        knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
        pivot_local_graphs = []
        pivot_knns = []
        # Neighbour sets of the kept graphs, parallel to ``pivot_knns``, so
        # they are not rebuilt for every IoU comparison.
        kept_neighbor_sets = []

        for pivot_ind, knn in enumerate(knn_graph):

            local_graph_neighbors = set(knn)

            for neighbor_ind in knn:
                local_graph_neighbors.update(
                    set(sorted_dist_inds[neighbor_ind,
                                         1:self.k_at_hops[1] + 1]))

            local_graph_neighbors.discard(pivot_ind)
            pivot_local_graph = list(local_graph_neighbors)
            pivot_local_graph.insert(0, pivot_ind)
            pivot_knn = [pivot_ind] + list(knn)

            # The first pivot is always kept: there is nothing to compare to.
            is_duplicate = False
            for added_knn, added_neighbors in zip(pivot_knns,
                                                  kept_neighbor_sets):
                added_pivot_ind = added_knn[0]

                union = len(local_graph_neighbors | added_neighbors)
                intersect = len(local_graph_neighbors & added_neighbors)
                local_graph_iou = intersect / (union + 1e-8)

                if (local_graph_iou > self.local_graph_thr
                        and pivot_ind in added_knn
                        and gt_comp_labels[added_pivot_ind]
                        == gt_comp_labels[pivot_ind]
                        and gt_comp_labels[pivot_ind] != 0):
                    is_duplicate = True
                    break

            if not is_duplicate:
                pivot_local_graphs.append(pivot_local_graph)
                pivot_knns.append(pivot_knn)
                kept_neighbor_sets.append(local_graph_neighbors)

        return pivot_local_graphs, pivot_knns

    def generate_gcn_input(
            self, node_feat_batch: List[Tensor],
            node_label_batch: List[ndarray],
            local_graph_batch: List[List[List[int]]],
            knn_batch: List[List[List[int]]],
            sorted_dist_ind_batch: List[ndarray]
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Generate graph convolution network input data.

        Every local graph is zero-padded to the largest local graph size
        across the whole batch so the results can be stacked.

        Args:
            node_feat_batch (List[Tensor]): The batched graph node features.
            node_label_batch (List[ndarray]): The batched text component
                labels.
            local_graph_batch (List[List[List[int]]]): The local graph node
                indices of image batch.
            knn_batch (List[List[List[int]]]): The knn graph node indices of
                image batch.
            sorted_dist_ind_batch (List[ndarray]): The node indices sorted
                according to the Euclidean distance.

        Returns:
            Tuple(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
            gt_linkage):

            - local_graphs_node_feat (Tensor): The node features of graph,
              relative to each graph's pivot feature, stacked to
              ``(num_graphs, num_max_nodes, feat_dim)``.
            - adjacent_matrices (Tensor): The normalized adjacent matrices of
              local graphs, ``(num_graphs, num_max_nodes, num_max_nodes)``.
            - pivots_knn_inds (Tensor): The k-nearest neighbor indices in
              local graph (built on CPU).
            - gt_linkage (Tensor): The supervision signal of GCN for linkage
              prediction (int64, built on CPU).
        """
        assert isinstance(node_feat_batch, list)
        assert isinstance(node_label_batch, list)
        assert isinstance(local_graph_batch, list)
        assert isinstance(knn_batch, list)
        assert isinstance(sorted_dist_ind_batch, list)

        num_max_nodes = max(
            len(pivot_local_graph) for pivot_local_graphs in local_graph_batch
            for pivot_local_graph in pivot_local_graphs)

        local_graphs_node_feat = []
        adjacent_matrices = []
        pivots_knn_inds = []
        pivots_gt_linkage = []

        for batch_ind, sorted_dist_inds in enumerate(sorted_dist_ind_batch):
            node_feats = node_feat_batch[batch_ind]
            pivot_local_graphs = local_graph_batch[batch_ind]
            pivot_knns = knn_batch[batch_ind]
            node_labels = node_label_batch[batch_ind]
            device = node_feats.device

            for graph_ind, pivot_knn in enumerate(pivot_knns):
                pivot_local_graph = pivot_local_graphs[graph_ind]
                num_nodes = len(pivot_local_graph)
                pivot_ind = pivot_local_graph[0]
                # Global node id -> position inside this local graph.
                node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}

                knn_local_inds = [node2ind_map[i] for i in pivot_knn[1:]]
                knn_inds = torch.tensor(knn_local_inds)
                pivot_feats = node_feats[pivot_ind]
                normalized_feats = node_feats[pivot_local_graph] - pivot_feats

                adjacent_matrix = np.zeros((num_nodes, num_nodes),
                                           dtype=np.float32)
                for node in pivot_local_graph:
                    neighbors = sorted_dist_inds[
                        node, 1:self.num_adjacent_linkages + 1]
                    for neighbor in neighbors:
                        # Dict membership: O(1) instead of a list scan.
                        if neighbor in node2ind_map:
                            adjacent_matrix[node2ind_map[node],
                                            node2ind_map[neighbor]] = 1
                            adjacent_matrix[node2ind_map[neighbor],
                                            node2ind_map[node]] = 1

                adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)
                pad_adjacent_matrix = torch.zeros(
                    (num_max_nodes, num_max_nodes),
                    dtype=torch.float,
                    device=device)
                pad_adjacent_matrix[:num_nodes, :num_nodes] = torch.from_numpy(
                    adjacent_matrix)

                pad_normalized_feats = torch.cat([
                    normalized_feats,
                    torch.zeros(
                        (num_max_nodes - num_nodes, normalized_feats.shape[1]),
                        dtype=torch.float,
                        device=device)
                ],
                                                 dim=0)

                local_graph_labels = node_labels[pivot_local_graph]
                # Index with a plain list so the result is always a 1-D
                # array, even for a single neighbour.
                knn_labels = local_graph_labels[knn_local_inds]
                # A neighbour links to the pivot iff both carry the same
                # positive (non-background) label.
                link_labels = ((node_labels[pivot_ind] == knn_labels) &
                               (node_labels[pivot_ind] > 0)).astype(np.int64)
                link_labels = torch.from_numpy(link_labels)

                local_graphs_node_feat.append(pad_normalized_feats)
                adjacent_matrices.append(pad_adjacent_matrix)
                pivots_knn_inds.append(knn_inds)
                pivots_gt_linkage.append(link_labels)

        local_graphs_node_feat = torch.stack(local_graphs_node_feat, 0)
        adjacent_matrices = torch.stack(adjacent_matrices, 0)
        pivots_knn_inds = torch.stack(pivots_knn_inds, 0)
        pivots_gt_linkage = torch.stack(pivots_gt_linkage, 0)

        return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
                pivots_gt_linkage)

    def __call__(self, feat_maps: Tensor, comp_attribs: ndarray
                 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Generate local graphs as GCN input.

        ``comp_attribs[b]`` is read as follows: ``[b, 0, 0]`` is the number
        of valid components, columns ``1:7`` are geometric attributes whose
        first two entries are the component center and whose last two are
        used as (cos, sin) of the rotation angle, and column ``7`` is the
        instance label. The input array is not modified.

        Args:
            feat_maps (Tensor): The feature maps to extract the content
                features of text components.
            comp_attribs (ndarray): The text component attributes, with
                shape ``(batch, max_comps, 8)``.

        Returns:
            Tuple(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
            gt_linkage):

            - local_graphs_node_feat (Tensor): The node features of graph.
            - adjacent_matrices (Tensor): The adjacent matrices of local
              graphs.
            - pivots_knn_inds (Tensor): The k-nearest neighbor indices in
              local graph.
            - gt_linkage (Tensor): The supervision signal of GCN for linkage
              prediction.
        """

        assert isinstance(feat_maps, Tensor)
        assert comp_attribs.ndim == 3
        assert comp_attribs.shape[2] == 8

        sorted_dist_inds_batch = []
        local_graph_batch = []
        knn_batch = []
        node_feat_batch = []
        node_label_batch = []
        device = feat_maps.device

        for batch_ind in range(comp_attribs.shape[0]):
            num_comps = int(comp_attribs[batch_ind, 0, 0])
            # Copy so that clipping below does not write into the caller's
            # array through a view.
            comp_geo_attribs = comp_attribs[batch_ind, :num_comps, 1:7].copy()
            node_labels = comp_attribs[batch_ind, :num_comps,
                                       7].astype(np.int32)

            comp_centers = comp_geo_attribs[:, 0:2]
            distance_matrix = euclidean_distance_matrix(
                comp_centers, comp_centers)

            # RoIs are pooled from a single-image feature map
            # (``feat_maps[batch_ind].unsqueeze(0)``), so the batch id is 0.
            batch_id = np.zeros((comp_geo_attribs.shape[0], 1),
                                dtype=np.float32)
            # Guard arccos against values slightly outside [-1, 1].
            comp_geo_attribs[:, -2] = np.clip(comp_geo_attribs[:, -2], -1, 1)
            angle = np.arccos(comp_geo_attribs[:, -2]) * np.sign(
                comp_geo_attribs[:, -1])
            angle = angle.reshape((-1, 1))
            rotated_rois = np.hstack(
                [batch_id, comp_geo_attribs[:, :-2], angle])
            rois = torch.from_numpy(rotated_rois).to(device)
            content_feats = self.pooling(feat_maps[batch_ind].unsqueeze(0),
                                         rois)

            content_feats = content_feats.view(content_feats.shape[0],
                                               -1).to(device)
            geo_feats = feature_embedding(comp_geo_attribs,
                                          self.node_geo_feat_dim)
            geo_feats = torch.from_numpy(geo_feats).to(device)
            node_feats = torch.cat([content_feats, geo_feats], dim=-1)

            sorted_dist_inds = np.argsort(distance_matrix, axis=1)
            pivot_local_graphs, pivot_knns = self.generate_local_graphs(
                sorted_dist_inds, node_labels)

            node_feat_batch.append(node_feats)
            node_label_batch.append(node_labels)
            local_graph_batch.append(pivot_local_graphs)
            knn_batch.append(pivot_knns)
            sorted_dist_inds_batch.append(sorted_dist_inds)

        (node_feats, adjacent_matrices, knn_inds, gt_linkage) = \
            self.generate_gcn_input(node_feat_batch,
                                    node_label_batch,
                                    local_graph_batch,
                                    knn_batch,
                                    sorted_dist_inds_batch)

        return node_feats, adjacent_matrices, knn_inds, gt_linkage


class ProposalLocalGraphs:
    """Propose text components and generate local graphs for GCN to classify
    the k-nearest neighbors of a pivot in `DRRG: Deep Relational Reasoning
    Graph Network for Arbitrary Shape Text Detection
    <https://arxiv.org/abs/2003.07493>`_.

    This code was partially adapted from https://github.com/GXYM/DRRG
    licensed under the MIT license.

    Args:
        k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2.
        num_adjacent_linkages (int): The number of linkages when constructing
            adjacent matrix.
        node_geo_feat_len (int): The length of embedded geometric feature
            vector of a text component.
        pooling_scale (float): The spatial scale of rotated RoI-Align.
        pooling_output_size (tuple(int)): The output size of rotated RoI-Align.
        nms_thr (float): The locality-aware NMS threshold for text components.
        min_width (float): The minimum width of text components.
        max_width (float): The maximum width of text components.
        comp_shrink_ratio (float): The shrink ratio of text components.
        comp_w_h_ratio (float): The width to height ratio of text components.
        comp_score_thr (float): The score threshold of text component.
        text_region_thr (float): The threshold for text region probability map.
        center_region_thr (float): The threshold for text center region
            probability map.
        center_region_area_thr (int): The threshold for filtering small-sized
            text center region.
    """

    def __init__(self, k_at_hops: Tuple[int, int], num_adjacent_linkages: int,
                 node_geo_feat_len: int, pooling_scale: float,
                 pooling_output_size: Sequence[int], nms_thr: float,
                 min_width: float, max_width: float, comp_shrink_ratio: float,
                 comp_w_h_ratio: float, comp_score_thr: float,
                 text_region_thr: float, center_region_thr: float,
                 center_region_area_thr: int) -> None:

        assert len(k_at_hops) == 2
        assert isinstance(k_at_hops, tuple)
        assert isinstance(num_adjacent_linkages, int)
        assert isinstance(node_geo_feat_len, int)
        assert isinstance(pooling_scale, float)
        assert isinstance(pooling_output_size, tuple)
        assert isinstance(nms_thr, float)
        assert isinstance(min_width, float)
        assert isinstance(max_width, float)
        assert isinstance(comp_shrink_ratio, float)
        assert isinstance(comp_w_h_ratio, float)
        assert isinstance(comp_score_thr, float)
        assert isinstance(text_region_thr, float)
        assert isinstance(center_region_thr, float)
        assert isinstance(center_region_area_thr, int)

        self.k_at_hops = k_at_hops
        self.active_connection = num_adjacent_linkages
        self.local_graph_depth = len(self.k_at_hops)
        self.node_geo_feat_dim = node_geo_feat_len
        self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
        self.nms_thr = nms_thr
        self.min_width = min_width
        self.max_width = max_width
        self.comp_shrink_ratio = comp_shrink_ratio
        self.comp_w_h_ratio = comp_w_h_ratio
        self.comp_score_thr = comp_score_thr
        self.text_region_thr = text_region_thr
        self.center_region_thr = center_region_thr
        self.center_region_area_thr = center_region_area_thr

    def propose_comps(self, score_map: ndarray, top_height_map: ndarray,
                      bot_height_map: ndarray, sin_map: ndarray,
                      cos_map: ndarray, comp_score_thr: float,
                      min_width: float, max_width: float,
                      comp_shrink_ratio: float,
                      comp_w_h_ratio: float) -> ndarray:
        """Propose text components.

        One rotated box is proposed for every pixel whose score exceeds
        ``comp_score_thr``.

        Args:
            score_map (ndarray): The score map for NMS.
            top_height_map (ndarray): The predicted text height map from each
                pixel in text center region to top sideline.
            bot_height_map (ndarray): The predicted text height map from each
                pixel in text center region to bottom sideline.
            sin_map (ndarray): The predicted sin(theta) map.
            cos_map (ndarray): The predicted cos(theta) map.
            comp_score_thr (float): The score threshold of text component.
            min_width (float): The minimum width of text components.
            max_width (float): The maximum width of text components.
            comp_shrink_ratio (float): The shrink ratio of text components.
            comp_w_h_ratio (float): The width to height ratio of text
                components.

        Returns:
            ndarray: The text components, one row per box:
            (x1, y1, x2, y2, x3, y3, x4, y4, score).
        """

        comp_centers = np.argwhere(score_map > comp_score_thr)
        comp_centers = comp_centers[np.argsort(comp_centers[:, 0])]
        y = comp_centers[:, 0]
        x = comp_centers[:, 1]

        top_height = top_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
        bot_height = bot_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
        sin = sin_map[y, x].reshape((-1, 1))
        cos = cos_map[y, x].reshape((-1, 1))

        # Centers are (row, col); these midpoints stay in (row, col) order.
        top_mid_pts = comp_centers + np.hstack(
            [top_height * sin, top_height * cos])
        bot_mid_pts = comp_centers - np.hstack(
            [bot_height * sin, bot_height * cos])

        width = (top_height + bot_height) * comp_w_h_ratio
        width = np.clip(width, min_width, max_width)
        r = width / 2

        # Half-width offset along the text direction, in (x, y) order to
        # match the ``[:, ::-1]`` flipped midpoints.
        half_width_offset = np.hstack([-r * sin, r * cos])
        tl = top_mid_pts[:, ::-1] - half_width_offset
        tr = top_mid_pts[:, ::-1] + half_width_offset
        br = bot_mid_pts[:, ::-1] + half_width_offset
        bl = bot_mid_pts[:, ::-1] - half_width_offset
        text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)

        score = score_map[y, x].reshape((-1, 1))
        text_comps = np.hstack([text_comps, score])

        return text_comps

    def propose_comps_and_attribs(self, text_region_map: ndarray,
                                  center_region_map: ndarray,
                                  top_height_map: ndarray,
                                  bot_height_map: ndarray, sin_map: ndarray,
                                  cos_map: ndarray) -> Tuple[ndarray, ndarray]:
        """Generate text components and attributes.

        Args:
            text_region_map (ndarray): The predicted text region probability
                map.
            center_region_map (ndarray): The predicted text center region
                probability map.
            top_height_map (ndarray): The predicted text height map from each
                pixel in text center region to top sideline.
            bot_height_map (ndarray): The predicted text height map from each
                pixel in text center region to bottom sideline.
            sin_map (ndarray): The predicted sin(theta) map.
            cos_map (ndarray): The predicted cos(theta) map.

        Returns:
            tuple(ndarray, ndarray): ``(None, None)`` when no component
            survives filtering, otherwise:

            - comp_attribs (ndarray): The text component attributes, one row
              per component: (x, y, h, w, cos, sin).
            - text_comps (ndarray): The text components.
        """

        assert (text_region_map.shape == center_region_map.shape ==
                top_height_map.shape == bot_height_map.shape == sin_map.shape
                == cos_map.shape)
        text_mask = text_region_map > self.text_region_thr
        center_region_mask = (center_region_map >
                              self.center_region_thr) * text_mask

        # Renormalise so that sin^2 + cos^2 is (approximately) 1.
        scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2 + 1e-8))
        sin_map, cos_map = sin_map * scale, cos_map * scale

        center_region_mask = fill_hole(center_region_mask)
        center_region_contours, _ = cv2.findContours(
            center_region_mask.astype(np.uint8), cv2.RETR_TREE,
            cv2.CHAIN_APPROX_SIMPLE)

        mask_sz = center_region_map.shape
        comp_list = []
        for contour in center_region_contours:
            current_center_mask = np.zeros(mask_sz)
            cv2.drawContours(current_center_mask, [contour], -1, 1, -1)
            if current_center_mask.sum() <= self.center_region_area_thr:
                continue
            score_map = text_region_map * current_center_mask

            text_comps = self.propose_comps(score_map, top_height_map,
                                            bot_height_map, sin_map, cos_map,
                                            self.comp_score_thr,
                                            self.min_width, self.max_width,
                                            self.comp_shrink_ratio,
                                            self.comp_w_h_ratio)

            if la_nms is None:
                raise ImportError('lanms-neo is not installed, '
                                  'please run "pip install lanms-neo==1.0.2".')
            text_comps = la_nms(text_comps, self.nms_thr)
            # Nothing survived (or nothing was proposed): the previous
            # ``shape[-1] > 0`` test looked at the column count and could
            # never skip an empty result; an empty array would also break
            # the ``[:, :8]`` slicing below if it came back 1-D.
            if text_comps.shape[0] == 0:
                continue

            text_comp_mask = np.zeros(mask_sz)
            text_comp_boxes = text_comps[:, :8].reshape(
                (-1, 4, 2)).astype(np.int32)

            cv2.drawContours(text_comp_mask, text_comp_boxes, -1, 1, -1)
            # Drop the contour if less than half of its boxes' area lies
            # inside the text region mask.
            if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5:
                continue
            comp_list.append(text_comps)

        if len(comp_list) <= 0:
            return None, None

        text_comps = np.vstack(comp_list)
        text_comp_boxes = text_comps[:, :8].reshape((-1, 4, 2))
        centers = np.mean(text_comp_boxes, axis=1).astype(np.int32)
        x = centers[:, 0]
        y = centers[:, 1]

        scores = []
        for text_comp_box in text_comp_boxes:
            # NOTE(review): ``text_comp_boxes`` is presumably a view of
            # ``text_comps``, so this clipping also clips the returned boxes.
            text_comp_box[:, 0] = np.clip(text_comp_box[:, 0], 0,
                                          mask_sz[1] - 1)
            text_comp_box[:, 1] = np.clip(text_comp_box[:, 1], 0,
                                          mask_sz[0] - 1)
            min_coord = np.min(text_comp_box, axis=0).astype(np.int32)
            max_coord = np.max(text_comp_box, axis=0).astype(np.int32)
            text_comp_box = text_comp_box - min_coord
            box_sz = (max_coord - min_coord + 1)
            temp_comp_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
            cv2.fillPoly(temp_comp_mask, [text_comp_box.astype(np.int32)], 1)
            temp_region_patch = text_region_map[min_coord[1]:(max_coord[1] +
                                                              1),
                                                min_coord[0]:(max_coord[0] +
                                                              1)]
            # Re-score each box by the mean text-region probability inside it.
            score = cv2.mean(temp_region_patch, temp_comp_mask)[0]
            scores.append(score)
        scores = np.array(scores).reshape((-1, 1))
        text_comps = np.hstack([text_comps[:, :-1], scores])

        h = top_height_map[y, x].reshape(
            (-1, 1)) + bot_height_map[y, x].reshape((-1, 1))
        w = np.clip(h * self.comp_w_h_ratio, self.min_width, self.max_width)
        sin = sin_map[y, x].reshape((-1, 1))
        cos = cos_map[y, x].reshape((-1, 1))

        x = x.reshape((-1, 1))
        y = y.reshape((-1, 1))
        comp_attribs = np.hstack([x, y, h, w, cos, sin])

        return comp_attribs, text_comps

    def generate_local_graphs(self, sorted_dist_inds: ndarray,
                              node_feats: Tensor
                              ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Generate local graphs and graph convolution network input data.

        Unlike training, every node becomes a pivot (no de-duplication).

        Args:
            sorted_dist_inds (ndarray): The node indices sorted according to
                the Euclidean distance.
            node_feats (tensor): The features of nodes in graph.

        Returns:
            Tuple(local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
            pivots_local_graphs):

            - local_graphs_node_feats (tensor): The features of nodes in local
              graphs, relative to each graph's pivot feature.
            - adjacent_matrices (tensor): The adjacent matrices.
            - pivots_knn_inds (tensor): The k-nearest neighbor indices in
              local graphs.
            - pivots_local_graphs (tensor): The indices of nodes in local
              graphs, zero-padded to the largest local graph (built on CPU).
        """

        assert sorted_dist_inds.ndim == 2
        assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
                node_feats.shape[0])

        knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
        pivot_local_graphs = []
        pivot_knns = []
        device = node_feats.device

        for pivot_ind, knn in enumerate(knn_graph):

            local_graph_neighbors = set(knn)

            for neighbor_ind in knn:
                local_graph_neighbors.update(
                    set(sorted_dist_inds[neighbor_ind,
                                         1:self.k_at_hops[1] + 1]))

            local_graph_neighbors.discard(pivot_ind)
            pivot_local_graph = list(local_graph_neighbors)
            pivot_local_graph.insert(0, pivot_ind)
            pivot_knn = [pivot_ind] + list(knn)

            pivot_local_graphs.append(pivot_local_graph)
            pivot_knns.append(pivot_knn)

        num_max_nodes = max(
            len(pivot_local_graph) for pivot_local_graph in pivot_local_graphs)

        local_graphs_node_feat = []
        adjacent_matrices = []
        pivots_knn_inds = []
        pivots_local_graphs = []

        for graph_ind, pivot_knn in enumerate(pivot_knns):
            pivot_local_graph = pivot_local_graphs[graph_ind]
            num_nodes = len(pivot_local_graph)
            pivot_ind = pivot_local_graph[0]
            # Global node id -> position inside this local graph.
            node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}

            knn_inds = torch.tensor([node2ind_map[i]
                                     for i in pivot_knn[1:]]).long().to(device)
            pivot_feats = node_feats[pivot_ind]
            normalized_feats = node_feats[pivot_local_graph] - pivot_feats

            adjacent_matrix = np.zeros((num_nodes, num_nodes))
            for node in pivot_local_graph:
                neighbors = sorted_dist_inds[node,
                                             1:self.active_connection + 1]
                for neighbor in neighbors:
                    # Dict membership: O(1) instead of a list scan.
                    if neighbor in node2ind_map:
                        adjacent_matrix[node2ind_map[node],
                                        node2ind_map[neighbor]] = 1
                        adjacent_matrix[node2ind_map[neighbor],
                                        node2ind_map[node]] = 1

            adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)
            pad_adjacent_matrix = torch.zeros((num_max_nodes, num_max_nodes),
                                              dtype=torch.float,
                                              device=device)
            pad_adjacent_matrix[:num_nodes, :num_nodes] = torch.from_numpy(
                adjacent_matrix)

            pad_normalized_feats = torch.cat([
                normalized_feats,
                torch.zeros(
                    (num_max_nodes - num_nodes, normalized_feats.shape[1]),
                    dtype=torch.float,
                    device=device)
            ],
                                             dim=0)

            local_graph_nodes = torch.tensor(pivot_local_graph)
            local_graph_nodes = torch.cat([
                local_graph_nodes,
                torch.zeros(num_max_nodes - num_nodes, dtype=torch.long)
            ],
                                          dim=-1)

            local_graphs_node_feat.append(pad_normalized_feats)
            adjacent_matrices.append(pad_adjacent_matrix)
            pivots_knn_inds.append(knn_inds)
            pivots_local_graphs.append(local_graph_nodes)

        local_graphs_node_feat = torch.stack(local_graphs_node_feat, 0)
        adjacent_matrices = torch.stack(adjacent_matrices, 0)
        pivots_knn_inds = torch.stack(pivots_knn_inds, 0)
        pivots_local_graphs = torch.stack(pivots_local_graphs, 0)

        return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
                pivots_local_graphs)

    def __call__(self, preds: Tensor, feat_maps: Tensor) -> Tuple[bool, Tuple]:
        """Generate local graphs and graph convolutional network input data.

        Args:
            preds (tensor): The predicted maps. Channels 0-5 are read as text
                region logits, center region logits, sin, cos, top height and
                bottom height. A 4-D input must have batch size 1.
            feat_maps (tensor): The feature maps to extract content feature of
                text components.

        Returns:
            Tuple(none_flag, graph_data):

            - none_flag (bool): True when fewer than two text components were
              proposed; ``graph_data`` is then ``(0, 0, 0, 0, 0)``.
            - graph_data (tuple): ``(local_graphs_node_feat,
              adjacent_matrices, pivots_knn_inds, pivots_local_graphs,
              text_comps)``, holding the features of nodes in local graphs,
              the adjacent matrices, the k-nearest neighbor indices in local
              graphs, the indices of nodes in local graphs (tensors) and the
              predicted text components (ndarray).
        """

        if preds.ndim == 4:
            assert preds.shape[0] == 1
            preds = torch.squeeze(preds)
        pred_text_region = torch.sigmoid(preds[0]).data.cpu().numpy()
        pred_center_region = torch.sigmoid(preds[1]).data.cpu().numpy()
        pred_sin_map = preds[2].data.cpu().numpy()
        pred_cos_map = preds[3].data.cpu().numpy()
        pred_top_height_map = preds[4].data.cpu().numpy()
        pred_bot_height_map = preds[5].data.cpu().numpy()
        device = preds.device

        comp_attribs, text_comps = self.propose_comps_and_attribs(
            pred_text_region, pred_center_region, pred_top_height_map,
            pred_bot_height_map, pred_sin_map, pred_cos_map)

        if comp_attribs is None or len(comp_attribs) < 2:
            none_flag = True
            return none_flag, (0, 0, 0, 0, 0)

        comp_centers = comp_attribs[:, 0:2]
        distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers)

        geo_feats = feature_embedding(comp_attribs, self.node_geo_feat_dim)
        geo_feats = torch.from_numpy(geo_feats).to(preds.device)

        batch_id = np.zeros((comp_attribs.shape[0], 1), dtype=np.float32)
        comp_attribs = comp_attribs.astype(np.float32)
        # Clip before arccos (as LocalGraphs does) so rounding past +/-1
        # cannot yield NaN angles.
        cos_vals = np.clip(comp_attribs[:, -2], -1, 1)
        angle = np.arccos(cos_vals) * np.sign(comp_attribs[:, -1])
        angle = angle.reshape((-1, 1))
        rotated_rois = np.hstack([batch_id, comp_attribs[:, :-2], angle])
        rois = torch.from_numpy(rotated_rois).to(device)

        content_feats = self.pooling(feat_maps, rois)
        content_feats = content_feats.view(content_feats.shape[0],
                                           -1).to(device)
        node_feats = torch.cat([content_feats, geo_feats], dim=-1)

        sorted_dist_inds = np.argsort(distance_matrix, axis=1)
        (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
         pivots_local_graphs) = self.generate_local_graphs(
             sorted_dist_inds, node_feats)

        none_flag = False
        return none_flag, (local_graphs_node_feat, adjacent_matrices,
                           pivots_knn_inds, pivots_local_graphs, text_comps)


class GraphConv(BaseModule):
    """Graph convolutional neural network.

    Each layer concatenates every node's features with the
    adjacency-weighted aggregate of its neighbours' features, then applies a
    shared linear map followed by ReLU.

    Args:
        in_dim (int): The number of input channels.
        out_dim (int): The number of output channels.
    """

    class MeanAggregator(BaseModule):
        """Mean aggregator for graph convolutional network.

        Implemented as a batched matrix product ``A @ features``; it is a
        mean only insofar as ``A`` is a normalized adjacency matrix.
        """

        def forward(self, features: Tensor, A: Tensor) -> Tensor:
            """Forward function."""
            x = torch.bmm(A, features)
            return x

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        # Input is [own features, aggregated features], hence ``in_dim * 2``.
        self.weight = nn.Parameter(torch.FloatTensor(in_dim * 2, out_dim))
        self.bias = nn.Parameter(torch.FloatTensor(out_dim))
        init.xavier_uniform_(self.weight)
        init.constant_(self.bias, 0)
        self.aggregator = self.MeanAggregator()

    def forward(self, features: Tensor, A: Tensor) -> Tensor:
        """Forward function.

        Args:
            features (Tensor): Node features of shape ``(B, N, in_dim)``.
            A (Tensor): Batched adjacency matrices of shape ``(B, N, N)``.

        Returns:
            Tensor: Node features of shape ``(B, N, out_dim)``.
        """
        _, _, d = features.shape
        assert d == self.in_dim
        agg_feats = self.aggregator(features, A)
        cat_feats = torch.cat([features, agg_feats], dim=2)
        out = torch.einsum('bnd,df->bnf', cat_feats, self.weight)
        out = F.relu(out + self.bias)
        return out


class GCN(BaseModule):
    """Graph convolutional network for clustering. This was from repo
    https://github.com/Zhongdao/gcn_clustering licensed under the MIT license.

    Args:
        feat_len (int): The input node feature length.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 feat_len: int,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.bn0 = nn.BatchNorm1d(feat_len, affine=False).float()
        self.conv1 = GraphConv(feat_len, 512)
        self.conv2 = GraphConv(512, 256)
        self.conv3 = GraphConv(256, 128)
        self.conv4 = GraphConv(128, 64)
        self.classifier = nn.Sequential(
            nn.Linear(64, 32), nn.PReLU(32), nn.Linear(32, 2))

    def forward(self, node_feats: Tensor, adj_mats: Tensor,
                knn_inds: Tensor) -> Tensor:
        """Forward function.

        Args:
            node_feats (Tensor): The node features of the local graphs, of
                shape ``(num_local_graphs, num_max_nodes, feat_len)``.
            adj_mats (Tensor): The adjacent matrices of local graphs, of shape
                ``(num_local_graphs, num_max_nodes, num_max_nodes)``.
            knn_inds (Tensor): The k-nearest neighbor indices in each local
                graph, of shape ``(num_local_graphs, K)``.

        Returns:
            Tensor: Linkage logits of shape ``(num_local_graphs * K, 2)``,
            ordered graph-major (row ``g * K + k``).
        """

        num_local_graphs, num_max_nodes, feat_len = node_feats.shape

        # Batch-normalise over every node of every local graph.
        node_feats = node_feats.view(-1, feat_len)
        node_feats = self.bn0(node_feats)
        node_feats = node_feats.view(num_local_graphs, num_max_nodes, feat_len)

        node_feats = self.conv1(node_feats, adj_mats)
        node_feats = self.conv2(node_feats, adj_mats)
        node_feats = self.conv3(node_feats, adj_mats)
        node_feats = self.conv4(node_feats, adj_mats)

        mid_feat_len = node_feats.size(-1)
        # Gather each graph's k-nearest-neighbour node features in one
        # advanced-indexing op instead of a Python loop over graphs.
        graph_inds = torch.arange(
            num_local_graphs, device=node_feats.device).unsqueeze(1)
        edge_feat = node_feats[graph_inds,
                               knn_inds.to(node_feats.device).long()]
        edge_feat = edge_feat.reshape(-1, mid_feat_len)
        pred = self.classifier(edge_feat)

        return pred
Read the Docs v: dev-1.x
Versions
latest
stable
v1.0.1
v1.0.0
0.x
v0.6.3
v0.6.2
v0.6.1
v0.6.0
v0.5.0
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.