Note
You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend that you upgrade to MMOCR 1.0 to enjoy the rich new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.
Source code for mmocr.datasets.ocr_seg_dataset
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.datasets.builder import DATASETS
import mmocr.utils as utils
from mmocr.datasets.ocr_dataset import OCRDataset
@DATASETS.register_module()
class OCRSegDataset(OCRDataset):
    """OCR dataset whose annotations are character-level boxes.

    The per-image character annotations are parsed into parallel lists of
    character strings, rectangles and quadrangles, and the result is passed
    through the data pipeline.
    """

    def _parse_anno_info(self, annotations):
        """Parse char boxes annotations.

        Args:
            annotations (list[dict]): Annotations of one image, where
                each dict is for one character. Each dict provides
                ``char_text`` and ``char_box`` (4 or 8 numbers). A
                4-number box may also provide ``char_box_type``, either
                ``'xyxy'`` (the default) or ``'xywh'``.

        Returns:
            dict: A dict containing the following keys:

                - chars (list[str]): List of character strings.
                - char_rects (list[list[float]]): List of char box, with each
                  in style of rectangle: [x_min, y_min, x_max, y_max].
                - char_quads (list[list[float]]): List of char box, with each
                  in style of quadrangle: [x1, y1, x2, y2, x3, y3, x4, y4].

        Raises:
            AssertionError: If ``utils.is_type_list(annotations, dict)`` is
                false, or the first annotation lacks ``char_box`` /
                ``char_text``, or its box does not have 4 or 8 numbers.
            ValueError: If a 4-number box has a ``char_box_type`` other than
                ``'xyxy'`` / ``'xywh'``, or if any box has a length other
                than 4 or 8.
        """
        # Up-front sanity checks look at the first annotation only; the
        # length of every box is still verified inside the loop below.
        # These stay as asserts so callers relying on AssertionError keep
        # working.
        assert utils.is_type_list(annotations, dict)
        assert 'char_box' in annotations[0]
        assert 'char_text' in annotations[0]
        assert len(annotations[0]['char_box']) in [4, 8]

        chars, char_rects, char_quads = [], [], []
        for ann in annotations:
            char_box = ann['char_box']
            if len(char_box) == 4:
                char_box_type = ann.get('char_box_type', 'xyxy')
                if char_box_type == 'xyxy':
                    x1, y1, x2, y2 = char_box
                    char_rects.append(char_box)
                elif char_box_type == 'xywh':
                    x1, y1, w, h = char_box
                    x2 = x1 + w
                    y2 = y1 + h
                    char_rects.append([x1, y1, x2, y2])
                else:
                    raise ValueError(f'invalid char_box_type {char_box_type}')
                # Expand the rectangle into its four corners, starting at
                # (x1, y1) and proceeding (x2, y1), (x2, y2), (x1, y2).
                char_quads.append([x1, y1, x2, y1, x2, y2, x1, y2])
            elif len(char_box) == 8:
                # Even positions are x coordinates, odd positions are y.
                x_list, y_list = char_box[0::2], char_box[1::2]
                # The rectangle is the axis-aligned bounds of the quadrangle.
                char_rects.append(
                    [min(x_list), min(y_list), max(x_list), max(y_list)])
                char_quads.append(char_box)
            else:
                # ValueError (a subclass of Exception) matches the error
                # raised above for an invalid box type.
                raise ValueError(
                    f'invalid num in char box: {len(char_box)} not in (4, 8)')
            chars.append(ann['char_text'])

        return dict(chars=chars, char_rects=char_rects, char_quads=char_quads)

    def prepare_train_img(self, index):
        """Get training data and annotations from pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """
        img_ann_info = self.data_infos[index]
        img_info = {
            'filename': img_ann_info['file_name'],
        }
        ann_info = self._parse_anno_info(img_ann_info['annotations'])
        results = dict(img_info=img_info, ann_info=ann_info)

        self.pre_pipeline(results)

        return self.pipeline(results)