# Note
# You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to enjoy the many new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.
# Source code for mmocr.datasets.pipelines.dbnet_transforms
# Copyright (c) OpenMMLab. All rights reserved.
import imgaug
import imgaug.augmenters as iaa
import mmcv
import numpy as np
from mmdet.core.mask import PolygonMasks
from mmdet.datasets.builder import PIPELINES
class AugmenterBuilder:
    """Build imgaug augmenters from plain Python config objects.

    Supported config forms:

    - ``None``: returns ``None``.
    - ``int`` / ``float`` / ``str``: returned unchanged.
    - ``list`` at the top level (``root=True``): every item is built with
      ``root=False`` and the results are wrapped in ``iaa.Sequential``.
    - ``list`` below the top level: ``[name, *params]`` becomes
      ``iaa.<name>(*params)``; list-valued params are turned into tuples.
    - ``dict`` with a ``'cls'`` key: becomes ``iaa.<cls>(**params)`` with
      the remaining items as keyword arguments (lists turned into tuples).
    - ``dict`` without ``'cls'``: every value is built recursively
      (``root=False``) and a dict with the same keys is returned.

    Any other type raises ``RuntimeError``.
    """

    def __init__(self):
        pass

    def build(self, args, root=True):
        """Convert ``args`` into an imgaug object (see the class docstring).

        Args:
            args: The config to convert.
            root (bool): Whether ``args`` is the top-level config. Only a
                top-level list is wrapped into ``iaa.Sequential``.

        Returns:
            The built augmenter, a plain scalar/string, a dict of built
            values, or ``None``.

        Raises:
            RuntimeError: If ``args`` has an unsupported type.
        """
        if args is None:
            return None
        if isinstance(args, (int, float, str)):
            return args

        if isinstance(args, list):
            if root:
                children = [self.build(item, root=False) for item in args]
                return iaa.Sequential(children)
            # Non-root list: first item names the augmenter, the rest are
            # its positional arguments.
            positional = [self.to_tuple_if_list(item) for item in args[1:]]
            augmenter_cls = getattr(iaa, args[0])
            return augmenter_cls(*positional)

        if isinstance(args, dict):
            if 'cls' not in args:
                return {
                    name: self.build(spec, root=False)
                    for name, spec in args.items()
                }
            augmenter_cls = getattr(iaa, args['cls'])
            keyword_args = {
                name: self.to_tuple_if_list(value)
                for name, value in args.items() if name != 'cls'
            }
            return augmenter_cls(**keyword_args)

        raise RuntimeError('unknown augmenter arg: ' + str(args))

    def to_tuple_if_list(self, obj):
        """Return ``tuple(obj)`` for lists and any other object unchanged."""
        return tuple(obj) if isinstance(obj, list) else obj
@PIPELINES.register_module()
class ImgAug:
    """A wrapper to use imgaug https://github.com/aleju/imgaug.

    Args:
        args (list[list|dict], optional): The augmentation list. For details,
            please refer to the imgaug documentation. Take
            ``args=[['Fliplr', 0.5], dict(cls='Affine', rotate=[-10, 10]),
            ['Resize', [0.5, 3.0]]]`` as an example: images are flipped
            horizontally with probability 0.5, then rotated by a random
            angle in [-10, 10], and finally each side of the image is resized
            with an independent scale in [0.5, 3.0].
        clip_invalid_ploys (bool): Whether to clip polygons that fall
            (partially) outside the image after transformation. ``False``
            keeps the original DBNet behavior (no clipping). The parameter
            keeps its historical misspelling for backward compatibility.
        clip_invalid_polys (bool, optional): Keyword-only, correctly spelled
            alias of ``clip_invalid_ploys``. When given, it takes precedence.
    """

    def __init__(self, args=None, clip_invalid_ploys=True, *,
                 clip_invalid_polys=None):
        self.augmenter_args = args
        self.augmenter = AugmenterBuilder().build(self.augmenter_args)
        if clip_invalid_polys is None:
            clip_invalid_polys = clip_invalid_ploys
        self.clip_invalid_polys = clip_invalid_polys

    def __call__(self, results):
        """Augment the image and, with the same transform, its annotations.

        Args:
            results (dict): Pipeline results. Reads ``'img'``; when an
                augmenter is configured, also reads ``'mask_fields'`` and
                ``'bbox_fields'`` and the fields they name.

        Returns:
            dict: ``results``, updated in place.
        """
        # img is bgr
        image = results['img']
        shape = image.shape
        # A falsy augmenter (e.g. None when no args were given) means no
        # augmentation; annotations are left untouched in that case.
        if not self.augmenter:
            return results
        # Freeze the sampled random parameters so the image and all of its
        # annotations receive exactly the same transformation.
        aug = self.augmenter.to_deterministic()
        results['img'] = aug.augment_image(image)
        results['img_shape'] = results['img'].shape
        # Whether a flip happened is not tracked through imgaug.
        results['flip'] = 'unknown'
        results['flip_direction'] = 'unknown'
        self.may_augment_annotation(aug, shape, results['img_shape'],
                                    results)
        return results

    def may_augment_annotation(self, aug, shape, target_shape, results):
        """Apply ``aug`` to every mask field and bbox field in ``results``.

        Args:
            aug: A deterministic imgaug augmenter, or ``None`` (no-op).
            shape (tuple): Image shape before augmentation.
            target_shape (tuple): Image shape after augmentation; its first
                two entries are used as the mask height and width.
            results (dict): Pipeline results, updated in place.

        Returns:
            dict: ``results``.
        """
        if aug is None:
            return results
        height, width = target_shape[:2]

        # augment polygon masks
        for key in results['mask_fields']:
            if self.clip_invalid_polys:
                masks = self.may_augment_poly(aug, shape, results[key])
            else:
                masks = self.may_augment_poly_legacy(aug, shape,
                                                     results[key])
            # Always rebuild so the masks carry the augmented image size,
            # even when there are no polygons.
            results[key] = PolygonMasks(masks, height, width)

        # augment bboxes
        for key in results['bbox_fields']:
            bboxes = self.may_augment_bbox(aug, shape, results[key])
            if bboxes:
                results[key] = np.stack(bboxes)
            else:
                # Keep a 2-D (0, 4) layout so downstream box arithmetic
                # still broadcasts when every box was dropped.
                results[key] = np.zeros((0, 4), dtype=np.float32)
        return results

    def may_augment_bbox(self, aug, ori_shape, bboxes):
        """Transform ``(x1, y1, x2, y2)`` boxes and clip them to the image.

        Args:
            aug: A deterministic imgaug augmenter.
            ori_shape (tuple): Image shape before augmentation.
            bboxes (Iterable): Boxes, each unpackable as ``x1, y1, x2, y2``.

        Returns:
            list[np.ndarray]: ``float32`` boxes ``[x1, y1, x2, y2]``. Boxes
            entirely outside the augmented image are removed by imgaug's
            ``clip_out_of_image``.
        """
        imgaug_bboxes = [
            imgaug.BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2)
            for x1, y1, x2, y2 in bboxes
        ]
        imgaug_bboxes = aug.augment_bounding_boxes([
            imgaug.BoundingBoxesOnImage(imgaug_bboxes, shape=ori_shape)
        ])[0].clip_out_of_image()
        return [
            np.array([box.x1, box.y1, box.x2, box.y2], dtype=np.float32)
            for box in imgaug_bboxes.bounding_boxes
        ]

    def may_augment_poly(self, aug, img_shape, polys):
        """Transform polygons and clip them to the augmented image.

        Args:
            aug: A deterministic imgaug augmenter.
            img_shape (tuple): Image shape before augmentation.
            polys (Iterable): Polygon instances; only the first part
                (``poly[0]``) of each is used, reshaped to ``(-1, 2)``.

        Returns:
            list[list[np.ndarray]]: One flattened ``float32`` polygon
            ``[x0, y0, x1, y1, ...]`` per entry. The count may differ from
            the input because clipping can drop or split polygons.
        """
        imgaug_polys = [
            imgaug.Polygon(poly[0].reshape(-1, 2)) for poly in polys
        ]
        imgaug_polys = aug.augment_polygons(
            [imgaug.PolygonsOnImage(imgaug_polys,
                                    shape=img_shape)])[0].clip_out_of_image()
        new_polys = []
        for poly in imgaug_polys.polygons:
            points = [np.array(point, dtype=np.float32) for point in poly]
            new_poly = np.array(points, dtype=np.float32).flatten()
            new_polys.append([new_poly])
        return new_polys

    def may_augment_poly_legacy(self, aug, img_shape, polys):
        """Transform polygon vertices as keypoints without any clipping.

        This reproduces the original DBNet behavior: polygons that end up
        partially or fully outside the image are kept as-is.

        Args:
            aug: A deterministic imgaug augmenter.
            img_shape (tuple): Image shape before augmentation.
            polys (Iterable): Polygon instances; only the first part
                (``poly[0]``) of each is used, reshaped to ``(-1, 2)``.

        Returns:
            list[list[np.ndarray]]: One flattened polygon per input polygon,
            in the same order.
        """
        key_points, poly_point_nums = [], []
        for poly in polys:
            points = poly[0].reshape(-1, 2)
            key_points.extend(imgaug.Keypoint(p[0], p[1]) for p in points)
            poly_point_nums.append(points.shape[0])
        # Warning: out-of-boundary polygons are intentionally not clipped.
        key_points = aug.augment_keypoints(
            [imgaug.KeypointsOnImage(keypoints=key_points,
                                     shape=img_shape)])[0].keypoints

        # Regroup the flat keypoint list back into per-polygon chunks.
        new_polys = []
        start_idx = 0
        for poly_point_num in poly_point_nums:
            chunk = key_points[start_idx:start_idx + poly_point_num]
            start_idx += poly_point_num
            new_poly = np.array([[kp.x, kp.y] for kp in chunk]).flatten()
            new_polys.append([new_poly])
        return new_polys

    def __repr__(self):
        return '{}(args={}, clip_invalid_polys={})'.format(
            self.__class__.__name__, self.augmenter_args,
            self.clip_invalid_polys)
@PIPELINES.register_module()
class EastRandomCrop:
    """Randomly crop the image, then rescale and pad it to ``target_size``.

    The crop window's borders are sampled from rows/columns that no text
    polygon spans (see ``crop_area``). The crop is resized to fit inside
    ``target_size`` while keeping its aspect ratio and is zero-padded at the
    bottom/right. Boxes and polygon masks are transformed accordingly and
    those lying entirely outside the resized crop are dropped.

    Args:
        target_size (tuple[int, int]): Output size as ``(width, height)``.
        max_tries (int): Maximum number of crop windows sampled before
            falling back to the whole image.
        min_crop_side_ratio (float): Minimum crop side length relative to
            the corresponding image side.
    """

    def __init__(self,
                 target_size=(640, 640),
                 max_tries=10,
                 min_crop_side_ratio=0.1):
        self.target_size = target_size
        self.max_tries = max_tries
        self.min_crop_side_ratio = min_crop_side_ratio

    def __call__(self, results):
        """Crop, resize and pad ``results['img']`` and its annotations.

        Args:
            results (dict): Pipeline results. Reads ``'img'``,
                ``'gt_masks'``, ``'bbox_fields'`` and ``'mask_fields'``.

        Returns:
            dict: ``results``, updated in place. When ``'gt_masks'`` is a
            mask field, ``'gt_labels'`` is rewritten with a label ``0`` per
            kept polygon.
        """
        img = results['img']
        crop_x, crop_y, crop_w, crop_h = self.crop_area(
            img, results['gt_masks'])
        target_w, target_h = self.target_size
        # Uniform scale so the resized crop fits inside the target size.
        scale = min(target_w / crop_w, target_h / crop_h)
        h = int(crop_h * scale)
        w = int(crop_w * scale)
        padded_img = np.zeros((target_h, target_w, img.shape[2]), img.dtype)
        padded_img[:h, :w] = mmcv.imresize(
            img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))

        offset = (crop_x, crop_y)
        # for bboxes
        for key in results['bbox_fields']:
            lines = []
            for box in results[key]:
                corners = (box.reshape(2, 2) - offset) * scale
                if not self.is_poly_outside_rect(corners, 0, 0, w, h):
                    lines.append(corners.flatten())
            # reshape keeps a (0, 4) layout when every box was dropped.
            results[key] = np.array(lines).reshape(-1, 4)

        # for masks
        for key in results['mask_fields']:
            polys = []
            polys_label = []
            for poly in results[key]:
                points = (np.array(poly).reshape(-1, 2) - offset) * scale
                if not self.is_poly_outside_rect(points, 0, 0, w, h):
                    # Store flattened [x0, y0, x1, y1, ...] like ImgAug does.
                    polys.append([points.flatten()])
                    polys_label.append(0)
            # PolygonMasks takes (height, width); target_size is (w, h).
            results[key] = PolygonMasks(polys, target_h, target_w)
            if key == 'gt_masks':
                results['gt_labels'] = polys_label

        results['img'] = padded_img
        results['img_shape'] = padded_img.shape
        return results

    def is_poly_in_rect(self, poly, x, y, w, h):
        """Return True if every vertex of ``poly`` lies inside the rect.

        ``poly`` must be indexable as ``poly[:, 0]`` / ``poly[:, 1]`` after
        ``np.array`` (i.e. an ``(N, 2)`` array).
        """
        poly = np.array(poly)
        if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
            return False
        if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
            return False
        return True

    def is_poly_outside_rect(self, poly, x, y, w, h):
        """Return True if ``poly`` lies entirely outside the rect.

        The test compares the polygon's bounding range with the rect on each
        axis, so a polygon touching the rect's border counts as inside.
        """
        poly = np.array(poly).reshape(-1, 2)
        if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
            return True
        if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
            return True
        return False

    def split_regions(self, axis):
        """Split a sorted index array into runs of consecutive integers.

        Args:
            axis (np.ndarray): Sorted 1-D integer indices.

        Returns:
            list[np.ndarray]: The contiguous runs, in order, including the
            final one.
        """
        regions = []
        start = 0
        for i in range(1, axis.shape[0]):
            if axis[i] != axis[i - 1] + 1:
                regions.append(axis[start:i])
                start = i
        if axis.shape[0] > 0:
            # The trailing run is not closed by a gap, so append it here.
            regions.append(axis[start:])
        return regions

    def random_select(self, axis, max_size):
        """Sample two values of ``axis`` (with replacement) as (min, max).

        Both values are clipped to ``[0, max_size - 1]``.
        """
        xx = np.random.choice(axis, size=2)
        xmin = np.clip(np.min(xx), 0, max_size - 1)
        xmax = np.clip(np.max(xx), 0, max_size - 1)
        return xmin, xmax

    def region_wise_random_select(self, regions):
        """Pick two regions (with replacement), one value from each.

        Returns:
            tuple: ``(min, max)`` of the two sampled values.
        """
        selected_index = np.random.choice(len(regions), 2)
        selected_values = [
            int(np.random.choice(regions[index])) for index in selected_index
        ]
        return min(selected_values), max(selected_values)

    def crop_area(self, img, polys):
        """Sample a crop window whose borders do not cut through text.

        Args:
            img (np.ndarray): Image of shape ``(h, w, c)``.
            polys (Iterable): Polygons; each is reshaped to ``(-1, 2)``.

        Returns:
            tuple: ``(x, y, crop_w, crop_h)``. Falls back to the whole image
            when every row or every column is covered by text, or when no
            window of sufficient size containing at least one polygon is
            found within ``max_tries`` attempts.
        """
        h, w, _ = img.shape
        h_array = np.zeros(h, dtype=np.int32)
        w_array = np.zeros(w, dtype=np.int32)
        for points in polys:
            points = np.round(
                points, decimals=0).astype(np.int32).reshape(-1, 2)
            # Clip to the image so negative coordinates (polygons that were
            # not clipped upstream) do not wrap around in the slices below.
            xs = np.clip(points[:, 0], 0, w)
            ys = np.clip(points[:, 1], 0, h)
            w_array[xs.min():xs.max()] = 1
            h_array[ys.min():ys.max()] = 1

        # ensure the cropped area does not cut across a text instance
        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]
        if len(h_axis) == 0 or len(w_axis) == 0:
            return 0, 0, w, h

        h_regions = self.split_regions(h_axis)
        w_regions = self.split_regions(w_axis)

        for _ in range(self.max_tries):
            if len(w_regions) > 1:
                xmin, xmax = self.region_wise_random_select(w_regions)
            else:
                xmin, xmax = self.random_select(w_axis, w)
            if len(h_regions) > 1:
                ymin, ymax = self.region_wise_random_select(h_regions)
            else:
                ymin, ymax = self.random_select(h_axis, h)

            crop_w = xmax - xmin
            crop_h = ymax - ymin
            if (crop_w < self.min_crop_side_ratio * w
                    or crop_h < self.min_crop_side_ratio * h):
                # area too small
                continue

            # Accept only windows that keep at least one text polygon.
            if any(not self.is_poly_outside_rect(poly, xmin, ymin, crop_w,
                                                 crop_h) for poly in polys):
                return xmin, ymin, crop_w, crop_h

        return 0, 0, w, h

    def __repr__(self):
        return '{}(target_size={}, max_tries={}, min_crop_side_ratio={})'.format(
            self.__class__.__name__, self.target_size, self.max_tries,
            self.min_crop_side_ratio)