Shortcuts

Note

You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. See the maintenance plan, changelog, code, and documentation of MMOCR 1.0 for more details.

Source code for mmocr.core.evaluation.ner_metric

# Copyright (c) OpenMMLab. All rights reserved.
from collections import Counter


def gt_label2entity(gt_infos):
    """Collect the labeled entities of every ground-truth record.

    Args:
        gt_infos (list[dict]): Ground-truth records. Each record's
            ``'label'`` entry maps a category to a dict whose values are
            lists of position pairs.

    Returns:
        list[list]: One list per record, each holding
            ``[category, start_position, end_position]`` triples.
    """
    entities_per_line = []
    for info in gt_infos:
        mentions_by_category = info['label']
        # Flatten category -> mention -> positions into plain triples.
        entities_per_line.append([
            [category, span[0], span[1]]
            for category, mentions in mentions_by_category.items()
            for spans in mentions.values()
            for span in spans
        ])
    return entities_per_line


def _compute_f1(origin, found, right):
    """Calculate recall, precision, f1-score.

    Args:
        origin (int): Original entities in groundtruth.
        found (int): Predicted entities from model.
        right (int): Predicted entities that
                        can match to the original annotation.
    Returns:
        recall (float): Metric of recall.
        precision (float): Metric of precision.
        f1 (float): Metric of f1-score.
    """
    recall = 0 if origin == 0 else (right / origin)
    precision = 0 if found == 0 else (right / found)
    f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (
        precision + recall)
    return recall, precision, f1


def compute_f1_all(pred_entities, gt_entities):
    """Compute precision, recall and f1-score overall and per category.

    Args:
        pred_entities: Per-line lists of entities predicted by the model.
        gt_entities: Per-line lists of ground-truth entities, indexed in
            step with ``pred_entities``.

    Returns:
        class_info (dict): Maps each ground-truth category, plus the key
            ``'all'``, to a dict with ``'precision'``, ``'recall'`` and
            ``'f1-score'``.
    """

    def _scores(n_gold, n_pred, n_hit):
        # Package the three metrics under the keys used in the report.
        recall, precision, f1 = _compute_f1(n_gold, n_pred, n_hit)
        return {'precision': precision, 'recall': recall, 'f1-score': f1}

    gold_pool, pred_pool, hit_pool = [], [], []
    for idx, line_preds in enumerate(pred_entities):
        line_golds = gt_entities[idx]
        gold_pool += line_golds
        pred_pool += line_preds
        # A prediction is a hit when the same entity appears in the ground
        # truth of the same line.
        hit_pool += [entity for entity in line_preds if entity in line_golds]

    # The first element of each entity is its category.
    gold_by_type = Counter(entity[0] for entity in gold_pool)
    pred_by_type = Counter(entity[0] for entity in pred_pool)
    hit_by_type = Counter(entity[0] for entity in hit_pool)

    # Only categories present in the ground truth get their own entry.
    class_info = {
        category: _scores(n_gold, pred_by_type.get(category, 0),
                          hit_by_type.get(category, 0))
        for category, n_gold in gold_by_type.items()
    }
    class_info['all'] = _scores(
        len(gold_pool), len(pred_pool), len(hit_pool))
    return class_info


def eval_ner_f1(results, gt_infos):
    """Evaluate for ner task.

    Args:
        results (list): Predict results of entities, one list of entities
            per ground-truth record.
        gt_infos (list[dict]): Ground-truth information which contains text
            and label.

    Returns:
        class_info (dict): precision, recall, f1-score of total and each
            category.

    Raises:
        AssertionError: If ``results`` and ``gt_infos`` differ in length.
    """
    assert len(results) == len(gt_infos), (
        f'results and gt_infos must have the same length, got '
        f'{len(results)} and {len(gt_infos)}')
    gt_entities = gt_label2entity(gt_infos)
    # Take a shallow copy of each line's predictions so the metric works on
    # its own lists rather than the caller's.
    pred_entities = [list(line_results) for line_results in results]
    return compute_f1_all(pred_entities, gt_entities)
Read the Docs v: v0.6.3
Versions
latest
stable
v0.6.3
v0.6.2
v0.6.1
v0.6.0
v0.5.0
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
dev-1.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.