Note
You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to enjoy the rich new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.
Source code for mmocr.datasets.icdar_dataset
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.coco import CocoDataset
import mmocr.utils as utils
from mmocr import digit_version
from mmocr.core.evaluation.hmean import eval_hmean
@DATASETS.register_module()
class IcdarDataset(CocoDataset):
    """Dataset for text detection while ann_file in coco format.

    Args:
        ann_file (str): Path of the COCO style annotation file.
        pipeline: Forwarded unchanged to ``CocoDataset``.
        classes: Forwarded unchanged to ``CocoDataset``.
        data_root: Forwarded unchanged to ``CocoDataset``.
        img_prefix: Forwarded unchanged to ``CocoDataset``.
        seg_prefix: Forwarded unchanged to ``CocoDataset``.
        proposal_file: Forwarded unchanged to ``CocoDataset``.
        test_mode: Forwarded unchanged to ``CocoDataset``.
        filter_empty_gt: Forwarded unchanged to ``CocoDataset``.
        select_first_k (int): If positive, only the first ``k`` images
            listed in the annotation file are loaded (handy for fast
            debugging). Any value <= 0 loads every image. Default to -1.
        ann_file_backend (str): Storage backend for annotation file,
            should be one in ['disk', 'petrel', 'http']. Default to 'disk'.
    """

    # Must be a one-element tuple: a bare ``('text')`` is just the string
    # 'text', which turns category-name lookups into substring tests.
    CLASSES = ('text', )

    def __init__(self,
                 ann_file,
                 pipeline,
                 classes=None,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True,
                 select_first_k=-1,
                 ann_file_backend='disk'):
        # Both attributes are read by ``load_annotations``, so they are set
        # before the parent constructor runs.
        # select first k images for fast debugging.
        self.select_first_k = select_first_k
        assert ann_file_backend in ['disk', 'petrel', 'http']
        self.ann_file_backend = ann_file_backend

        super().__init__(
            ann_file=ann_file,
            pipeline=pipeline,
            classes=classes,
            data_root=data_root,
            img_prefix=img_prefix,
            seg_prefix=seg_prefix,
            proposal_file=proposal_file,
            test_mode=test_mode,
            filter_empty_gt=filter_empty_gt)

        # Set dummy flags just to be compatible with MMDet
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def load_annotations(self, ann_file):
        """Load annotation from COCO style annotation file.

        Sets ``self.coco``, ``self.cat_ids``, ``self.cat2label`` and
        ``self.img_ids`` as a side effect. ``self.img_ids`` always holds
        every image id of the annotation file, even when ``select_first_k``
        shortens the returned list.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation info from COCO api. Each dict also gets
            a ``filename`` key mirroring its ``file_name`` value.

        Raises:
            Exception: If a non-disk backend is requested while the
                installed mmcv is older than 1.3.16.
        """
        if self.ann_file_backend == 'disk':
            self.coco = COCO(ann_file)
        else:
            mmcv_version = digit_version(mmcv.__version__)
            if mmcv_version < digit_version('1.3.16'):
                raise Exception('Please update mmcv to 1.3.16 or higher '
                                'to enable "get_local_path" of "FileClient".')
            file_client = mmcv.FileClient(backend=self.ann_file_backend)
            # The annotation file is parsed through the local path that the
            # file client provides inside this context.
            with file_client.get_local_path(ann_file) as local_path:
                self.coco = COCO(local_path)

        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()

        limit = self.select_first_k
        data_infos = []
        for img_id in self.img_ids:
            # Stop once exactly ``select_first_k`` images were collected;
            # a non-positive limit disables the cap.
            if 0 < limit <= len(data_infos):
                break
            info = self.coco.load_imgs([img_id])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
        return data_infos

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotation.

        Annotations flagged ``ignore``, with non-positive area, with a box
        narrower or shorter than one pixel, or of a category outside
        ``self.cat_ids`` are dropped. ``iscrowd`` annotations go to the
        ``*_ignore`` outputs.

        Args:
            img_info (dict): Image info; only its ``filename`` is read.
            ann_info (list[dict]): Annotation info of an image.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
            labels, masks, masks_ignore, seg_map. "masks" and
            "masks_ignore" are represented by polygon boundary
            point sequences.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ignore = []
        gt_masks_ann = []

        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue
            # COCO stores boxes as (x, y, w, h); convert to corner format.
            bbox = [x1, y1, x1 + w, y1 + h]
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
                gt_masks_ignore.append(ann.get('segmentation', None))
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])
                gt_masks_ann.append(ann.get('segmentation', None))

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        # Swap only a trailing '.jpg' extension. A blanket
        # ``replace('jpg', 'png')`` would also rewrite 'jpg' inside
        # directory or base names (e.g. 'jpg_imgs/1.jpg').
        seg_map = img_info['filename']
        if seg_map.endswith('.jpg'):
            seg_map = seg_map[:-len('.jpg')] + '.png'

        ann = dict(
            bboxes=gt_bboxes,
            labels=gt_labels,
            bboxes_ignore=gt_bboxes_ignore,
            masks_ignore=gt_masks_ignore,
            masks=gt_masks_ann,
            seg_map=seg_map)

        return ann

    def evaluate(self,
                 results,
                 metric='hmean-iou',
                 logger=None,
                 score_thr=None,
                 min_score_thr=0.3,
                 max_score_thr=0.9,
                 step=0.1,
                 rank_list=None,
                 **kwargs):
        """Evaluate the hmean metric.

        Args:
            results (list[dict]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Only
                'hmean-iou' and 'hmean-ic13' are kept; any other name is
                dropped silently.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            score_thr (float): Deprecated. Please use min_score_thr instead.
            min_score_thr (float): Minimum score threshold of prediction map.
            max_score_thr (float): Maximum score threshold of prediction map.
            step (float): The spacing between score thresholds.
            rank_list (str): json file used to save eval result
                of each image after ranking.
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            dict[dict[str: float]]: The evaluation results, as returned by
            ``eval_hmean``.
        """
        assert utils.is_type_list(results, dict)

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['hmean-iou', 'hmean-ic13']
        metrics = set(metrics) & set(allowed_metrics)

        # Gather per-image file names and ground truths in dataset order.
        indices = range(len(self))
        img_infos = [{
            'filename': self.data_infos[i]['file_name']
        } for i in indices]
        ann_infos = [self.get_ann_info(i) for i in indices]

        eval_results = eval_hmean(
            results,
            img_infos,
            ann_infos,
            metrics=metrics,
            score_thr=score_thr,
            min_score_thr=min_score_thr,
            max_score_thr=max_score_thr,
            step=step,
            logger=logger,
            rank_list=rank_list)

        return eval_results