Note
You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.
Source code for mmocr.core.evaluation.ner_metric
# Copyright (c) OpenMMLab. All rights reserved.
from collections import Counter
def gt_label2entity(gt_infos):
    """Get all entities from ground truth infos.

    Args:
        gt_infos (list[dict]): Ground-truth information, one dict per line.
            Each dict must have a ``'label'`` key mapping a category to a
            dict whose values are lists of spans; the first two items of a
            span are its start and end positions. The inner dict's keys are
            not used here (presumably the entity text — confirm against the
            dataset format).

    Returns:
        gt_entities (list[list]): Original labeled entities in groundtruth,
            one list per line, each entity being
            ``[category, start_position, end_position]``.
    """
    gt_entities = []
    for gt_info in gt_infos:
        line_entities = []
        for category, spans_by_key in gt_info['label'].items():
            # Only the spans matter, so iterate values instead of items.
            for places in spans_by_key.values():
                for place in places:
                    line_entities.append([category, place[0], place[1]])
        gt_entities.append(line_entities)
    return gt_entities
def _compute_f1(origin, found, right):
"""Calculate recall, precision, f1-score.
Args:
origin (int): Original entities in groundtruth.
found (int): Predicted entities from model.
right (int): Predicted entities that
can match to the original annotation.
Returns:
recall (float): Metric of recall.
precision (float): Metric of precision.
f1 (float): Metric of f1-score.
"""
recall = 0 if origin == 0 else (right / origin)
precision = 0 if found == 0 else (right / found)
f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (
precision + recall)
return recall, precision, f1
def compute_f1_all(pred_entities, gt_entities):
    """Calculate precision, recall and F1-score for all categories.

    Predictions are compared with the ground truth line by line. A
    predicted entity is counted as correct when an equal entity exists in
    the ground truth of the same line. Each ground-truth entity can be
    matched by at most one prediction, so repeated predictions of the same
    entity cannot push recall above 1.

    Args:
        pred_entities (list[list]): The predicted entities from model, one
            list per line. The first element of each entity is its
            category; presumably ``[category, start_position,
            end_position]`` like the ground truth — confirm with the caller.
        gt_entities (list[list]): The entities of ground truth, one list
            per line, each entity being ``[category, start_position,
            end_position]``. Must have at least as many lines as
            ``pred_entities`` (otherwise ``IndexError``); extra lines are
            ignored.

    Returns:
        class_info (dict): Maps every category found in the ground truth,
            plus ``'all'`` for the totals over all entities, to a dict with
            keys ``'precision'``, ``'recall'`` and ``'f1-score'``.
            Categories that appear only in predictions are counted in
            ``'all'`` but get no entry of their own.
    """
    origins = []
    founds = []
    rights = []
    for i, line_preds in enumerate(pred_entities):
        # Index rather than zip so a shorter gt_entities still raises
        # IndexError instead of being silently truncated.
        line_gts = gt_entities[i]
        origins.extend(line_gts)
        founds.extend(line_preds)
        # Consume each matched ground-truth entity so it is credited at
        # most once, even if the same entity is predicted repeatedly.
        unmatched = list(line_gts)
        for pred_entity in line_preds:
            if pred_entity in unmatched:
                unmatched.remove(pred_entity)
                rights.append(pred_entity)

    def _metrics(origin, found, right):
        # Pack the three counts into the per-category result dict.
        recall, precision, f1 = _compute_f1(origin, found, right)
        return {'precision': precision, 'recall': recall, 'f1-score': f1}

    origin_counter = Counter(x[0] for x in origins)
    found_counter = Counter(x[0] for x in founds)
    right_counter = Counter(x[0] for x in rights)
    class_info = {
        type_: _metrics(count, found_counter.get(type_, 0),
                        right_counter.get(type_, 0))
        for type_, count in origin_counter.items()
    }
    class_info['all'] = _metrics(len(origins), len(founds), len(rights))
    return class_info
def eval_ner_f1(results, gt_infos):
    """Evaluate for ner task.

    Args:
        results (list): Predicted entities, one iterable of entities per
            line, in the same order as ``gt_infos``.
        gt_infos (list[dict]): Ground-truth information which contains
            text and label.

    Returns:
        class_info (dict): Precision, recall and f1-score in total and for
            each category.
    """
    assert len(results) == len(gt_infos)
    gt_entities = gt_label2entity(gt_infos)
    # Copy each line's predictions into a fresh list; after the length
    # check above this yields exactly one entry per ground-truth line.
    pred_entities = [list(line_results) for line_results in results]
    return compute_f1_all(pred_entities, gt_entities)