Shortcuts

Note

You are reading the documentation for MMOCR 0.x, which will soon be deprecated by the end of 2022. We recommend you upgrade to MMOCR 1.0 to enjoy fruitful new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.

Source code for mmocr.datasets.pipelines.textdet_targets.dbnet_targets

# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np
import pyclipper
from mmdet.core import BitmapMasks
from mmdet.datasets.builder import PIPELINES
from shapely.geometry import Polygon

from . import BaseTextDetTargets


[docs]@PIPELINES.register_module() class DBNetTargets(BaseTextDetTargets): """Generate gt shrunk text, gt threshold map, and their effective region masks to learn DBNet: Real-time Scene Text Detection with Differentiable Binarization [https://arxiv.org/abs/1911.08947]. This was partially adapted from https://github.com/MhLiao/DB. Args: shrink_ratio (float): The area shrunk ratio between text kernels and their text masks. thr_min (float): The minimum value of the threshold map. thr_max (float): The maximum value of the threshold map. min_short_size (int): The minimum size of polygon below which the polygon is invalid. """ def __init__(self, shrink_ratio=0.4, thr_min=0.3, thr_max=0.7, min_short_size=8): super().__init__() self.shrink_ratio = shrink_ratio self.thr_min = thr_min self.thr_max = thr_max self.min_short_size = min_short_size
[docs] def find_invalid(self, results): """Find invalid polygons. Args: results (dict): The dict containing gt_mask. Returns: ignore_tags (list[bool]): The indicators for ignoring polygons. """ texts = results['gt_masks'].masks ignore_tags = [False] * len(texts) for idx, text in enumerate(texts): if self.invalid_polygon(text[0]): ignore_tags[idx] = True return ignore_tags
[docs] def invalid_polygon(self, poly): """Judge the input polygon is invalid or not. It is invalid if its area smaller than 1 or the shorter side of its minimum bounding box smaller than min_short_size. Args: poly (ndarray): The polygon boundary point sequence. Returns: True/False (bool): Whether the polygon is invalid. """ area = self.polygon_area(poly) if abs(area) < 1: return True short_size = min(self.polygon_size(poly)) if short_size < self.min_short_size: return True return False
[docs] def ignore_texts(self, results, ignore_tags): """Ignore gt masks and gt_labels while padding gt_masks_ignore in results given ignore_tags. Args: results (dict): Result for one image. ignore_tags (list[int]): Indicate whether to ignore its corresponding ground truth text. Returns: results (dict): Results after filtering. """ flag_len = len(ignore_tags) assert flag_len == len(results['gt_masks'].masks) assert flag_len == len(results['gt_labels']) results['gt_masks_ignore'].masks += [ mask for i, mask in enumerate(results['gt_masks'].masks) if ignore_tags[i] ] results['gt_masks'].masks = [ mask for i, mask in enumerate(results['gt_masks'].masks) if not ignore_tags[i] ] results['gt_labels'] = np.array([ mask for i, mask in enumerate(results['gt_labels']) if not ignore_tags[i] ]) new_ignore_tags = [ignore for ignore in ignore_tags if not ignore] return results, new_ignore_tags
[docs] def generate_thr_map(self, img_size, polygons): """Generate threshold map. Args: img_size (tuple(int)): The image size (h,w) polygons (list(ndarray)): The polygon list. Returns: thr_map (ndarray): The generated threshold map. thr_mask (ndarray): The effective mask of threshold map. """ thr_map = np.zeros(img_size, dtype=np.float32) thr_mask = np.zeros(img_size, dtype=np.uint8) for polygon in polygons: self.draw_border_map(polygon[0], thr_map, mask=thr_mask) thr_map = thr_map * (self.thr_max - self.thr_min) + self.thr_min return thr_map, thr_mask
[docs] def draw_border_map(self, polygon, canvas, mask): """Generate threshold map for one polygon. Args: polygon(ndarray): The polygon boundary ndarray. canvas(ndarray): The generated threshold map. mask(ndarray): The generated threshold mask. """ polygon = polygon.reshape(-1, 2) assert polygon.ndim == 2 assert polygon.shape[1] == 2 polygon_shape = Polygon(polygon) distance = ( polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length) subject = [tuple(p) for p in polygon] padding = pyclipper.PyclipperOffset() padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) padded_polygon = padding.Execute(distance) if len(padded_polygon) > 0: padded_polygon = np.array(padded_polygon[0]) else: print(f'padding {polygon} with {distance} gets {padded_polygon}') padded_polygon = polygon.copy().astype(np.int32) x_min = padded_polygon[:, 0].min() x_max = padded_polygon[:, 0].max() y_min = padded_polygon[:, 1].min() y_max = padded_polygon[:, 1].max() width = x_max - x_min + 1 height = y_max - y_min + 1 polygon[:, 0] = polygon[:, 0] - x_min polygon[:, 1] = polygon[:, 1] - y_min xs = np.broadcast_to( np.linspace(0, width - 1, num=width).reshape(1, width), (height, width)) ys = np.broadcast_to( np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width)) distance_map = np.zeros((polygon.shape[0], height, width), dtype=np.float32) for i in range(polygon.shape[0]): j = (i + 1) % polygon.shape[0] absolute_distance = self.point2line(xs, ys, polygon[i], polygon[j]) distance_map[i] = np.clip(absolute_distance / distance, 0, 1) distance_map = distance_map.min(axis=0) x_min_valid = min(max(0, x_min), canvas.shape[1] - 1) x_max_valid = min(max(0, x_max), canvas.shape[1] - 1) y_min_valid = min(max(0, y_min), canvas.shape[0] - 1) y_max_valid = min(max(0, y_max), canvas.shape[0] - 1) if x_min_valid - x_min >= width or y_min_valid - y_min >= height: return cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) canvas[y_min_valid:y_max_valid + 1, x_min_valid:x_max_valid + 1] = np.fmax( 1 - distance_map[y_min_valid - y_min:y_max_valid - y_max + height, x_min_valid - x_min:x_max_valid - x_max + width], canvas[y_min_valid:y_max_valid + 1, x_min_valid:x_max_valid + 1])
[docs] def generate_targets(self, results): """Generate the gt targets for DBNet. Args: results (dict): The input result dictionary. Returns: results (dict): The output result dictionary. """ assert isinstance(results, dict) if 'bbox_fields' in results: results['bbox_fields'].clear() ignore_tags = self.find_invalid(results) results, ignore_tags = self.ignore_texts(results, ignore_tags) h, w, _ = results['img_shape'] polygons = results['gt_masks'].masks # generate gt_shrink_kernel gt_shrink, ignore_tags = self.generate_kernels((h, w), polygons, self.shrink_ratio, ignore_tags=ignore_tags) results, ignore_tags = self.ignore_texts(results, ignore_tags) # genenrate gt_shrink_mask polygons_ignore = results['gt_masks_ignore'].masks gt_shrink_mask = self.generate_effective_mask((h, w), polygons_ignore) # generate gt_threshold and gt_threshold_mask polygons = results['gt_masks'].masks gt_thr, gt_thr_mask = self.generate_thr_map((h, w), polygons) results['mask_fields'].clear() # rm gt_masks encoded by polygons results.pop('gt_labels', None) results.pop('gt_masks', None) results.pop('gt_bboxes', None) results.pop('gt_bboxes_ignore', None) mapping = { 'gt_shrink': gt_shrink, 'gt_shrink_mask': gt_shrink_mask, 'gt_thr': gt_thr, 'gt_thr_mask': gt_thr_mask } for key, value in mapping.items(): value = value if isinstance(value, list) else [value] results[key] = BitmapMasks(value, h, w) results['mask_fields'].append(key) return results
Read the Docs v: v0.6.3
Versions
latest
stable
v0.6.3
v0.6.2
v0.6.1
v0.6.0
v0.5.0
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
dev-1.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.