
Source code for mmocr.datasets.icdar_dataset

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.coco import CocoDataset

import mmocr.utils as utils
from mmocr import digit_version
from mmocr.core.evaluation.hmean import eval_hmean

@DATASETS.register_module()
class IcdarDataset(CocoDataset):
    """Dataset for text detection while ann_file in coco format.

    Args:
        ann_file_backend (str): Storage backend for annotation file,
            should be one in ['disk', 'petrel', 'http']. Default to 'disk'.
    """
    # A one-element tuple. ``('text')`` would just be the string 'text', so
    # code expecting a sequence of class names (e.g. ``len(self.CLASSES)``)
    # would see four one-character classes instead of a single 'text' class.
    CLASSES = ('text', )

    def __init__(self,
                 ann_file,
                 pipeline,
                 classes=None,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True,
                 select_first_k=-1,
                 ann_file_backend='disk'):
        """Build the dataset.

        Args:
            ann_file, pipeline, classes, data_root, img_prefix, seg_prefix,
                proposal_file, test_mode, filter_empty_gt: Forwarded
                positionally to ``CocoDataset.__init__``.
            select_first_k (int): If positive, only the first
                ``select_first_k`` images of the annotation file are loaded
                (for fast debugging). Non-positive values load all images.
            ann_file_backend (str): One of 'disk', 'petrel', 'http'.
        """
        # These are stored before ``super().__init__`` because
        # ``load_annotations`` reads them; presumably the base constructor
        # invokes it (NOTE(review): confirm against the MMDet version used).
        # select first k images for fast debugging.
        self.select_first_k = select_first_k
        assert ann_file_backend in ['disk', 'petrel', 'http']
        self.ann_file_backend = ann_file_backend

        super().__init__(ann_file, pipeline, classes, data_root, img_prefix,
                         seg_prefix, proposal_file, test_mode,
                         filter_empty_gt)

        # Set dummy flags just to be compatible with MMDet
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def load_annotations(self, ann_file):
        """Load annotation from COCO style annotation file.

        Also sets ``self.coco``, ``self.cat_ids``, ``self.cat2label`` and
        ``self.img_ids``. When ``self.select_first_k`` is positive, at most
        that many image infos are returned.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation info from COCO api. Each dict additionally
            carries a ``'filename'`` key copied from ``'file_name'``.
        """
        if self.ann_file_backend == 'disk':
            self.coco = COCO(ann_file)
        else:
            mmcv_version = digit_version(mmcv.__version__)
            if mmcv_version < digit_version('1.3.16'):
                raise Exception('Please update mmcv to 1.3.16 or higher '
                                'to enable "get_local_path" of "FileClient".')
            file_client = mmcv.FileClient(backend=self.ann_file_backend)
            # ``get_local_path`` yields a local file path that COCO can open.
            with file_client.get_local_path(ann_file) as local_path:
                self.coco = COCO(local_path)

        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()

        data_infos = []
        for img_id in self.img_ids:
            # Stop once exactly ``select_first_k`` images are collected. The
            # former ``count > k`` test fired only after the (k+1)-th image
            # had been appended, loading one image too many.
            if 0 < self.select_first_k <= len(data_infos):
                break
            info = self.coco.load_imgs([img_id])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
        return data_infos

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotation.

        Args:
            img_info (dict): Image info; its ``'filename'`` is used to derive
                ``seg_map``.
            ann_info (list[dict]): Annotation info of an image.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
            labels, masks, masks_ignore, seg_map. "masks" and
            "masks_ignore" are represented by polygon boundary
            point sequences.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ignore = []
        gt_masks_ann = []

        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            # Skip degenerate boxes: non-positive area or width/height < 1.
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue
            # COCO xywh -> x1y1x2y2.
            bbox = [x1, y1, x1 + w, y1 + h]
            if ann.get('iscrowd', False):
                # Crowd instances become "ignore" boxes/masks.
                gt_bboxes_ignore.append(bbox)
                gt_masks_ignore.append(ann.get('segmentation', None))
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])
                gt_masks_ann.append(ann.get('segmentation', None))

        # Convert to float32 / int64 arrays for later processing; empty
        # inputs get correctly shaped empty arrays.
        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        filename = img_info['filename']
        # Swap only a trailing 'jpg' extension. ``str.replace`` would also
        # rewrite 'jpg' appearing elsewhere in the path (e.g. a directory).
        if filename.endswith('jpg'):
            seg_map = filename[:-len('jpg')] + 'png'
        else:
            seg_map = filename

        ann = dict(
            bboxes=gt_bboxes,
            labels=gt_labels,
            bboxes_ignore=gt_bboxes_ignore,
            masks_ignore=gt_masks_ignore,
            masks=gt_masks_ann,
            seg_map=seg_map)

        return ann

    def evaluate(self,
                 results,
                 metric='hmean-iou',
                 logger=None,
                 score_thr=None,
                 min_score_thr=0.3,
                 max_score_thr=0.9,
                 step=0.1,
                 rank_list=None,
                 **kwargs):
        """Evaluate the hmean metric.

        Args:
            results (list[dict]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Only
                'hmean-iou' and 'hmean-ic13' are used; other names are
                silently dropped.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            score_thr (float): Deprecated. Please use min_score_thr instead.
            min_score_thr (float): Minimum score threshold of prediction map.
            max_score_thr (float): Maximum score threshold of prediction map.
            step (float): The spacing between score thresholds.
            rank_list (str): json file used to save eval result
                of each image after ranking.

        Returns:
            dict[dict[str: float]]: The evaluation results.
        """
        assert utils.is_type_list(results, dict)

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['hmean-iou', 'hmean-ic13']
        # Intersection keeps only supported metric names.
        metrics = set(metrics) & set(allowed_metrics)

        img_infos = []
        ann_infos = []
        for i in range(len(self)):
            img_info = {'filename': self.data_infos[i]['file_name']}
            img_infos.append(img_info)
            ann_infos.append(self.get_ann_info(i))

        eval_results = eval_hmean(
            results,
            img_infos,
            ann_infos,
            metrics=metrics,
            score_thr=score_thr,
            min_score_thr=min_score_thr,
            max_score_thr=max_score_thr,
            step=step,
            logger=logger,
            rank_list=rank_list)

        return eval_results
Read the Docs v: v0.6.1
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.