
# Source code for mmocr.models.textdet.postprocessors.pse_postprocessor

# Copyright (c) OpenMMLab. All rights reserved.

from typing import List

import cv2
import numpy as np
import torch
from mmcv.ops import contour_expand
from mmengine.structures import InstanceData

from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from .pan_postprocessor import PANPostprocessor

@MODELS.register_module()
class PSEPostprocessor(PANPostprocessor):
    """Decoding predictions of PSENet to instances.

    This is partially adapted from the reference PSENet implementation
    (https://github.com/whai362/PSENet).

    Args:
        text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
            Defaults to 'poly'.
        rescale_fields (list[str]): The bbox/polygon field names to be
            rescaled. If None, no rescaling will be performed.
            Defaults to ['polygons'].
        min_kernel_confidence (float): The minimal kernel confidence.
            Defaults to 0.5.
        score_threshold (float): The minimal text average confidence.
            Defaults to 0.3.
        min_kernel_area (int): The minimal text kernel area.
            Defaults to 0.
        min_text_area (int): The minimal text instance region area.
            Defaults to 16.
        downsample_ratio (float): Downsample ratio. Defaults to 0.25.
    """

    def __init__(self,
                 text_repr_type: str = 'poly',
                 rescale_fields: List[str] = ['polygons'],
                 min_kernel_confidence: float = 0.5,
                 score_threshold: float = 0.3,
                 min_kernel_area: int = 0,
                 min_text_area: int = 16,
                 downsample_ratio: float = 0.25) -> None:
        # The default list object is shared by every call that omits
        # ``rescale_fields``; pass the parent a private copy so a later
        # in-place mutation cannot leak into other instances.
        if rescale_fields is not None:
            rescale_fields = list(rescale_fields)
        super().__init__(
            text_repr_type=text_repr_type,
            rescale_fields=rescale_fields,
            min_kernel_confidence=min_kernel_confidence,
            score_threshold=score_threshold,
            min_text_area=min_text_area,
            downsample_ratio=downsample_ratio)
        self.min_kernel_area = min_kernel_area

    def get_text_instances(self, pred_results: torch.Tensor,
                           data_sample: TextDetDataSample,
                           **kwargs) -> TextDetDataSample:
        """Decode the raw PSENet output of one image into text instances.

        Args:
            pred_results (torch.Tensor): Prediction results of an image which
                is a tensor of shape :math:`(N, H, W)`.
            data_sample (TextDetDataSample): Datasample of an image.

        Returns:
            TextDetDataSample: The input ``data_sample`` with predictions
            filled in. Polygons are saved in
            ``TextDetDataSample.pred_instances.polygons``. The confidence
            scores are saved in ``TextDetDataSample.pred_instances.scores``.
            Its ``scale_factor`` meta info is multiplied by
            ``downsample_ratio``.
        """
        assert pred_results.dim() == 3

        pred_results = torch.sigmoid(pred_results)  # text confidence

        masks = pred_results > self.min_kernel_confidence
        # Channel 0 is used as the text mask; every map is restricted to it.
        text_mask = masks[0, :, :]
        kernel_masks = masks[0:, :, :] * text_mask
        # OpenCV and ``contour_expand`` operate on host-side uint8 arrays.
        kernel_masks = kernel_masks.detach().cpu().numpy().astype(np.uint8)

        score = pred_results[0, :, :]
        score = score.detach().cpu().numpy().astype(np.float32)

        # Seed labels from the last kernel map, then grow them through the
        # remaining maps with ``contour_expand``.
        region_num, labels = cv2.connectedComponents(
            kernel_masks[-1], connectivity=4)

        labels = contour_expand(kernel_masks, labels, self.min_kernel_area,
                                region_num)
        labels = np.array(labels)
        label_num = np.max(labels)

        polygons = []
        scores = []
        for i in range(1, label_num + 1):
            # ``np.where`` yields (row, col); reverse to (x, y) points.
            points = np.array(np.where(labels == i)).transpose((1, 0))[:, ::-1]
            area = points.shape[0]
            score_instance = np.mean(score[labels == i])
            # NOTE(review): an instance is kept when EITHER the area OR the
            # score criterion is met. The parameter docs describe both as
            # minimums, which suggests ``and`` - confirm the intent before
            # changing, since it alters detection results.
            if not (area >= self.min_text_area
                    or score_instance > self.score_threshold):
                continue

            polygon = self._points2boundary(points)
            if polygon:
                polygons.append(polygon)
                scores.append(score_instance)

        pred_instances = InstanceData()
        pred_instances.polygons = polygons
        pred_instances.scores = torch.FloatTensor(scores)
        data_sample.pred_instances = pred_instances

        # Fold the prediction-map downsampling into the stored scale factor
        # (presumably so the later rescale step maps polygons back to the
        # original image size - confirm against the base postprocessor).
        scale_factor = data_sample.scale_factor
        scale_factor = tuple(factor * self.downsample_ratio
                             for factor in scale_factor)
        data_sample.set_metainfo(dict(scale_factor=scale_factor))
        return data_sample
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.