
Note

You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to enjoy the new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.

Source code for mmocr.datasets.icdar_dataset

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.coco import CocoDataset

import mmocr.utils as utils
from mmocr import digit_version
from mmocr.core.evaluation.hmean import eval_hmean


@DATASETS.register_module()
class IcdarDataset(CocoDataset):
    """Dataset for text detection whose annotation file is in COCO format.

    Args:
        ann_file_backend (str): Storage backend for the annotation file,
            should be one of ['disk', 'petrel', 'http']. Defaults to 'disk'.
    """
    CLASSES = ('text', )

    def __init__(self,
                 ann_file,
                 pipeline,
                 classes=None,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True,
                 select_first_k=-1,
                 ann_file_backend='disk'):
        # Keep only the first k images for fast debugging (-1 keeps all).
        self.select_first_k = select_first_k
        assert ann_file_backend in ['disk', 'petrel', 'http']
        self.ann_file_backend = ann_file_backend

        super().__init__(
            ann_file=ann_file,
            pipeline=pipeline,
            classes=classes,
            data_root=data_root,
            img_prefix=img_prefix,
            seg_prefix=seg_prefix,
            proposal_file=proposal_file,
            test_mode=test_mode,
            filter_empty_gt=filter_empty_gt)

        # Set dummy flags just to be compatible with MMDet.
        self.flag = np.zeros(len(self), dtype=np.uint8)
    def load_annotations(self, ann_file):
        """Load annotations from a COCO-style annotation file.

        Args:
            ann_file (str): Path of the annotation file.

        Returns:
            list[dict]: Annotation info from the COCO api.
        """
        if self.ann_file_backend == 'disk':
            self.coco = COCO(ann_file)
        else:
            mmcv_version = digit_version(mmcv.__version__)
            if mmcv_version < digit_version('1.3.16'):
                raise Exception('Please update mmcv to 1.3.16 or higher '
                                'to enable "get_local_path" of "FileClient".')
            file_client = mmcv.FileClient(backend=self.ann_file_backend)
            with file_client.get_local_path(ann_file) as local_path:
                self.coco = COCO(local_path)
        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()

        data_infos = []
        count = 0
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
            count = count + 1
            # Stop once the first k images have been collected.
            if 0 < self.select_first_k <= count:
                break
        return data_infos
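    # An illustrative sketch of the COCO-style layout that
    # load_annotations expects; the file name and coordinates below are
    # made-up placeholders, not shipped data:
    #
    #     {
    #         'images': [{'id': 1, 'file_name': 'img_1.jpg',
    #                     'height': 720, 'width': 1280}],
    #         'categories': [{'id': 1, 'name': 'text'}],
    #         'annotations': [{
    #             'id': 1, 'image_id': 1, 'category_id': 1,
    #             'bbox': [100, 200, 50, 20],  # [x, y, w, h]
    #             'area': 1000, 'iscrowd': 0,
    #             'segmentation': [[100, 200, 150, 200,
    #                               150, 220, 100, 220]]}]
    #     }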
    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotations.

        Args:
            img_info (dict): Image info of an image.
            ann_info (list[dict]): Annotation info of an image.

        Returns:
            dict: A dict containing the following keys: bboxes,
                bboxes_ignore, labels, masks, masks_ignore, seg_map.
                "masks" and "masks_ignore" are represented by polygon
                boundary point sequences.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ignore = []
        gt_masks_ann = []

        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
                gt_masks_ignore.append(ann.get(
                    'segmentation', None))  # to float32 for later processing
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])
                gt_masks_ann.append(ann.get('segmentation', None))

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        seg_map = img_info['filename'].replace('jpg', 'png')

        ann = dict(
            bboxes=gt_bboxes,
            labels=gt_labels,
            bboxes_ignore=gt_bboxes_ignore,
            masks_ignore=gt_masks_ignore,
            masks=gt_masks_ann,
            seg_map=seg_map)

        return ann
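    # For the single made-up annotation sketched above, _parse_ann_info
    # would return roughly the following:
    #
    #     dict(
    #         bboxes=np.array([[100., 200., 150., 220.]], dtype=np.float32),
    #         labels=np.array([0]),  # index of the only category, 'text'
    #         bboxes_ignore=np.zeros((0, 4), dtype=np.float32),
    #         masks_ignore=[],
    #         masks=[[[100, 200, 150, 200, 150, 220, 100, 220]]],
    #         seg_map='img_1.png')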
    def evaluate(self,
                 results,
                 metric='hmean-iou',
                 logger=None,
                 score_thr=None,
                 min_score_thr=0.3,
                 max_score_thr=0.9,
                 step=0.1,
                 rank_list=None,
                 **kwargs):
        """Evaluate the hmean metric.

        Args:
            results (list[dict]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            score_thr (float): Deprecated. Please use min_score_thr instead.
            min_score_thr (float): Minimum score threshold of the
                prediction map.
            max_score_thr (float): Maximum score threshold of the
                prediction map.
            step (float): The spacing between score thresholds.
            rank_list (str): Path of a json file used to save the evaluation
                result of each image after ranking.

        Returns:
            dict[str: float]: The evaluation results.
        """
        assert utils.is_type_list(results, dict)

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['hmean-iou', 'hmean-ic13']
        metrics = set(metrics) & set(allowed_metrics)

        img_infos = []
        ann_infos = []
        for i in range(len(self)):
            img_info = {'filename': self.data_infos[i]['file_name']}
            img_infos.append(img_info)
            ann_infos.append(self.get_ann_info(i))

        eval_results = eval_hmean(
            results,
            img_infos,
            ann_infos,
            metrics=metrics,
            score_thr=score_thr,
            min_score_thr=min_score_thr,
            max_score_thr=max_score_thr,
            step=step,
            logger=logger,
            rank_list=rank_list)

        return eval_results
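Below is a minimal usage sketch, not part of the module above. It assumes an ICDAR split already converted to COCO format at the placeholder paths shown, and that each per-image result dict carries a 'boundary_result' list of polygon coordinates followed by a confidence score, the layout consumed by eval_hmean.

from mmocr.datasets.icdar_dataset import IcdarDataset

# Placeholder paths; point these at a real COCO-converted ICDAR split.
dataset = IcdarDataset(
    ann_file='data/icdar2015/instances_test.json',
    pipeline=[dict(type='LoadImageFromFile')],
    img_prefix='data/icdar2015/imgs',
    test_mode=True)

# One dummy prediction per image: 8 polygon coordinates plus a score.
results = [
    dict(boundary_result=[[100, 200, 150, 200, 150, 220, 100, 220, 0.95]])
    for _ in range(len(dataset))
]
print(dataset.evaluate(results, metric='hmean-iou'))

Note that evaluate collects ground truth through get_ann_info, so the annotation file passed to the dataset must contain instance annotations, not just image entries.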