Source code for mmocr.core.evaluation.ocr_metric

# Copyright (c) OpenMMLab. All rights reserved.
import re
from difflib import SequenceMatcher

from rapidfuzz.distance import Levenshtein

from mmocr.utils import is_type_list


def cal_true_positive_char(pred, gt):
    """Calculate correct character number in prediction.

    Args:
        pred (str): Prediction text.
        gt (str): Ground truth text.

    Returns:
        true_positive_char_num (int): The number of true positive characters.
    """

    all_opt = SequenceMatcher(None, pred, gt)
    true_positive_char_num = 0
    for opt, _, _, s2, e2 in all_opt.get_opcodes():
        if opt == 'equal':
            true_positive_char_num += (e2 - s2)
    return true_positive_char_num
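
# Illustrative example (added to this page, not part of the upstream module):
# characters covered by the 'equal' opcodes of SequenceMatcher are counted as
# true positives; inserted, deleted or replaced characters are ignored.
#
#   >>> cal_true_positive_char('hello', 'hello')
#   5
#   >>> cal_true_positive_char('hell', 'hello')
#   4
#   >>> cal_true_positive_char('', 'hello')
#   0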


def count_matches(pred_texts, gt_texts):
    """Count the various match number for metric calculation.

    Args:
        pred_texts (list[str]): Predicted text strings.
        gt_texts (list[str]): Ground truth text strings.

    Returns:
        match_res (dict[str: int]): Match numbers used for
            metric calculation.
    """
    match_res = {
        'gt_char_num': 0,
        'pred_char_num': 0,
        'true_positive_char_num': 0,
        'gt_word_num': 0,
        'match_word_num': 0,
        'match_word_ignore_case': 0,
        'match_word_ignore_case_symbol': 0
    }
    # Symbols are defined as everything that is not an ASCII letter, a digit
    # or a CJK character; they are stripped for the *_ignore_case_symbol and
    # character-level statistics below.
    comp = re.compile('[^A-Za-z0-9\u4e00-\u9fa5]')
    norm_ed_sum = 0.0
    for pred_text, gt_text in zip(pred_texts, gt_texts):
        if gt_text == pred_text:
            match_res['match_word_num'] += 1
        gt_text_lower = gt_text.lower()
        pred_text_lower = pred_text.lower()
        if gt_text_lower == pred_text_lower:
            match_res['match_word_ignore_case'] += 1
        gt_text_lower_ignore = comp.sub('', gt_text_lower)
        pred_text_lower_ignore = comp.sub('', pred_text_lower)
        if gt_text_lower_ignore == pred_text_lower_ignore:
            match_res['match_word_ignore_case_symbol'] += 1
        match_res['gt_word_num'] += 1

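        # Accumulate the normalized Levenshtein distance between the
        # case- and symbol-normalized strings; it is averaged over all
        # samples after the loop and reported as 'ned'.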
        norm_ed_sum += Levenshtein.normalized_distance(pred_text_lower_ignore,
                                                       gt_text_lower_ignore)

        # number to calculate char level recall & precision
        match_res['gt_char_num'] += len(gt_text_lower_ignore)
        match_res['pred_char_num'] += len(pred_text_lower_ignore)
        true_positive_char_num = cal_true_positive_char(
            pred_text_lower_ignore, gt_text_lower_ignore)
        match_res['true_positive_char_num'] += true_positive_char_num

    normalized_edit_distance = norm_ed_sum / max(1, len(gt_texts))
    match_res['ned'] = normalized_edit_distance

    return match_res
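
# Illustrative example (added to this page, not part of the upstream module):
# a doctest-style sketch of ``count_matches``. 'Hello!' only matches 'hello'
# once letter case and symbols are ignored, while all five characters survive
# the normalization and are counted as true positives.
#
#   >>> res = count_matches(['Hello!'], ['hello'])
#   >>> (res['match_word_num'], res['match_word_ignore_case'],
#   ...  res['match_word_ignore_case_symbol'])
#   (0, 0, 1)
#   >>> res['gt_char_num'], res['pred_char_num'], res['true_positive_char_num']
#   (5, 5, 5)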


def eval_ocr_metric(pred_texts, gt_texts, metric='acc'):
    """Evaluate the text recognition performance with metric: word accuracy
    and 1-N.E.D. See https://rrc.cvc.uab.es/?ch=14&com=tasks for details.

    Args:
        pred_texts (list[str]): Text strings of prediction.
        gt_texts (list[str]): Text strings of ground truth.
        metric (str | list[str]): Metric(s) to be evaluated. Options are:

            - 'word_acc': Accuracy at word level.
            - 'word_acc_ignore_case': Accuracy at word level, ignoring letter
              case.
            - 'word_acc_ignore_case_symbol': Accuracy at word level, ignoring
              letter case and symbol. (Default metric for academic evaluation)
            - 'char_recall': Recall at character level, ignoring letter case
              and symbol.
            - 'char_precision': Precision at character level, ignoring letter
              case and symbol.
            - 'one_minus_ned': 1 - normalized_edit_distance.

            In particular, if ``metric == 'acc'``, results on all metrics
            above will be reported.

    Returns:
        dict{str: float}: Result dict for text recognition, keys could be some
        of the following: ['word_acc', 'word_acc_ignore_case',
        'word_acc_ignore_case_symbol', 'char_recall', 'char_precision',
        '1-N.E.D'].
    """
    assert isinstance(pred_texts, list)
    assert isinstance(gt_texts, list)
    assert len(pred_texts) == len(gt_texts)
    assert isinstance(metric, str) or is_type_list(metric, str)
    if metric == 'acc' or metric == ['acc']:
        metric = [
            'word_acc', 'word_acc_ignore_case', 'word_acc_ignore_case_symbol',
            'char_recall', 'char_precision', 'one_minus_ned'
        ]
    metric = set([metric]) if isinstance(metric, str) else set(metric)

    supported_metrics = set([
        'word_acc', 'word_acc_ignore_case', 'word_acc_ignore_case_symbol',
        'char_recall', 'char_precision', 'one_minus_ned'
    ])
    assert metric.issubset(supported_metrics)

    match_res = count_matches(pred_texts, gt_texts)
    eps = 1e-8
    eval_res = {}

    if 'char_recall' in metric:
        char_recall = 1.0 * match_res['true_positive_char_num'] / (
            eps + match_res['gt_char_num'])
        eval_res['char_recall'] = char_recall

    if 'char_precision' in metric:
        char_precision = 1.0 * match_res['true_positive_char_num'] / (
            eps + match_res['pred_char_num'])
        eval_res['char_precision'] = char_precision

    if 'word_acc' in metric:
        word_acc = 1.0 * match_res['match_word_num'] / (
            eps + match_res['gt_word_num'])
        eval_res['word_acc'] = word_acc

    if 'word_acc_ignore_case' in metric:
        word_acc_ignore_case = 1.0 * match_res['match_word_ignore_case'] / (
            eps + match_res['gt_word_num'])
        eval_res['word_acc_ignore_case'] = word_acc_ignore_case

    if 'word_acc_ignore_case_symbol' in metric:
        word_acc_ignore_case_symbol = 1.0 * match_res[
            'match_word_ignore_case_symbol'] / (
                eps + match_res['gt_word_num'])
        eval_res['word_acc_ignore_case_symbol'] = word_acc_ignore_case_symbol

    if 'one_minus_ned' in metric:
        eval_res['1-N.E.D'] = 1.0 - match_res['ned']

    for key, value in eval_res.items():
        eval_res[key] = float('{:.4f}'.format(value))

    return eval_res
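

# Illustrative usage (added to this page, not part of the upstream module):
# a minimal sketch of evaluating a toy batch with the default ``metric='acc'``.
# Here 'world' matches its ground truth exactly, while 'Hello!' only matches
# once letter case and symbols are ignored, so ``word_acc`` comes out lower
# than ``word_acc_ignore_case_symbol``.
if __name__ == '__main__':
    pred_texts = ['Hello!', 'world']
    gt_texts = ['hello', 'world']
    results = eval_ocr_metric(pred_texts, gt_texts)
    print(results['word_acc'], results['word_acc_ignore_case_symbol'],
          results['1-N.E.D'])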