Shortcuts

Source code for mmocr.datasets.openset_kie_dataset

import copy

import numpy as np
import torch
from mmdet.datasets.builder import DATASETS

from mmocr.datasets import KIEDataset


@DATASETS.register_module()
class OpensetKIEDataset(KIEDataset):
    """Openset KIE classifies the nodes (i.e. text boxes) into bg/key/value
    categories, and additionally learns key-value relationship among nodes.

    Args:
        ann_file (str): Annotation file path.
        loader (dict): Dictionary to construct loader to load annotation
            infos.
        dict_file (str): Character dict file path.
        img_prefix (str, optional): Image prefix to generate full image path.
        pipeline (list[dict]): Processing pipeline.
        norm (float): Norm to map value from one range to another.
        link_type (str): ``one-to-one`` | ``one-to-many`` |
            ``many-to-one`` | ``many-to-many`` | ``none``. For
            ``many-to-many``, one key box can have many values and vice
            versa. ``none`` is decoded the same way as ``many-to-many``.
        edge_thr (float): Score threshold for a valid edge.
        test_mode (bool, optional): If True, try...except will be turned off
            in __getitem__.
        key_node_idx (int): Index of key in node classes.
        value_node_idx (int): Index of value in node classes.
        node_classes (int): Number of node classes.
    """

    def __init__(self,
                 ann_file,
                 loader,
                 dict_file,
                 img_prefix='',
                 pipeline=None,
                 norm=10.,
                 link_type='one-to-one',
                 edge_thr=0.5,
                 test_mode=True,
                 key_node_idx=1,
                 value_node_idx=2,
                 node_classes=4):
        # NOTE(review): the positional ``False`` is presumably the parent's
        # ``directed`` flag -- confirm against ``KIEDataset.__init__``.
        super().__init__(ann_file, loader, dict_file, img_prefix, pipeline,
                         norm, False, test_mode)
        assert link_type in [
            'one-to-one', 'one-to-many', 'many-to-one', 'many-to-many', 'none'
        ]
        self.link_type = link_type
        # Index the annotation infos by file name so that predictions can be
        # matched to their ground truth in ``decode_pred`` / ``decode_gt``.
        self.data_dict = {x['file_name']: x for x in self.data_infos}
        self.edge_thr = edge_thr
        self.key_node_idx = key_node_idx
        self.value_node_idx = value_node_idx
        self.node_classes = node_classes

    def pre_pipeline(self, results):
        """Run the parent ``pre_pipeline`` and expose the original texts and
        boxes stored in ``results['ann_info']`` at the top level."""
        super().pre_pipeline(results)
        results['ori_texts'] = results['ann_info']['ori_texts']
        results['ori_boxes'] = results['ann_info']['ori_boxes']

    def list_to_numpy(self, ann_infos):
        """Extend the parent's converted annotations with the untouched
        ``texts`` and ``boxes`` under ``ori_texts`` / ``ori_boxes``."""
        results = super().list_to_numpy(ann_infos)
        results.update(dict(ori_texts=ann_infos['texts']))
        results.update(dict(ori_boxes=ann_infos['boxes']))
        return results

    def evaluate(self,
                 results,
                 metric='openset_f1',
                 metric_options=None,
                 **kwargs):
        """Evaluate model outputs with the openset F1 metrics.

        Args:
            results (list[dict]): Model outputs, one per image, in the
                format consumed by :meth:`decode_pred`.
            metric (str | list[str]): Metric name(s). Only ``openset_f1``
                is supported.
            metric_options (dict, optional): Accepted for interface
                compatibility; currently not used by the computation.

        Returns:
            dict: See :meth:`compute_openset_f1`.

        Raises:
            KeyError: If an unsupported metric is requested.
        """
        # Protect ``metric_options`` since it uses mutable value as default
        metric_options = copy.deepcopy(metric_options)

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['openset_f1']
        for m in metrics:
            if m not in allowed_metrics:
                raise KeyError(f'metric {m} is not supported')

        preds, gts = [], []
        for result in results:
            # data for preds
            pred = self.decode_pred(result)
            preds.append(pred)
            # data for gts
            gt = self.decode_gt(pred['filename'])
            gts.append(gt)

        return self.compute_openset_f1(preds, gts)

    def _decode_pairs_gt(self, labels, edge_ids):
        """Find all pairs in gt.

        A key node ``i`` and a value node ``j`` form a pair when they share
        the same edge id. The first index in the pair (n1, n2) is key.
        """
        gt_pairs = []
        for i, label in enumerate(labels):
            if label != self.key_node_idx:
                continue
            for j, edge_id in enumerate(edge_ids):
                if (edge_id == edge_ids[i]
                        and labels[j] == self.value_node_idx):
                    gt_pairs.append((i, j))

        return gt_pairs

    @staticmethod
    def _decode_pairs_pred(nodes,
                           labels,
                           edges,
                           edge_thr=0.5,
                           link_type='one-to-one',
                           key_node_idx=1,
                           value_node_idx=2):
        """Find all pairs in prediction.

        The first index in the pair (n1, n2) is more likely to be a key
        according to prediction in nodes.

        Args:
            nodes (Tensor): Per-node class scores, indexed as
                ``nodes[node, class]``.
            labels (Tensor): Predicted class index of every node.
            edges (Tensor): Square matrix of pairwise edge scores.
            edge_thr (float): Score threshold for a valid edge.
            link_type (str): Linking constraint, see the class docstring.
            key_node_idx (int): Index of key in node classes.
            value_node_idx (int): Index of value in node classes.

        Returns:
            tuple(list[tuple[int, int]], list[float]): The ``(key, value)``
            index pairs and the symmetrized edge score of each pair.

        Raises:
            ValueError: If a link has to be resolved with an unsupported
                ``link_type``.
        """
        # Symmetrize: an edge is as confident as its stronger direction.
        edges = torch.max(edges, edges.T)
        if link_type in ['none', 'many-to-many']:
            pair_inds = (edges > edge_thr).nonzero(as_tuple=True)
            pred_pairs = []
            for n1, n2 in zip(*pair_inds):
                # The matrix is symmetric; visit each unordered pair once.
                if not n1 < n2:
                    continue
                n1, n2 = n1.item(), n2.item()
                # Orient the pair using n1's key score vs. its value score.
                if nodes[n1, key_node_idx] > nodes[n1, value_node_idx]:
                    key, value = n1, n2
                else:
                    key, value = n2, n1
                if (labels[key] == key_node_idx
                        and labels[value] == value_node_idx):
                    pred_pairs.append((key, value))
        else:
            links = edges.clone()
            # -1 marks an unavailable link: below threshold, row is not a
            # predicted key, or column is not a predicted value.
            links[links <= edge_thr] = -1
            links[labels != key_node_idx, :] = -1
            links[:, labels != value_node_idx] = -1

            pred_pairs = []
            num_cols = links.size(1)
            # Greedily take the strongest remaining link, then invalidate
            # the links that the link type no longer allows.
            while (links > -1).any():
                # ``argmax`` indexes the flattened (row-major) matrix.
                i, j = divmod(torch.argmax(links).item(), num_cols)
                pred_pairs.append((i, j))
                if link_type == 'one-to-one':
                    links[i, :] = -1
                    links[:, j] = -1
                elif link_type == 'one-to-many':
                    links[:, j] = -1
                elif link_type == 'many-to-one':
                    links[i, :] = -1
                else:
                    raise ValueError(f'not supported link type {link_type}')

        pairs_conf = [edges[i, j].item() for i, j in pred_pairs]
        return pred_pairs, pairs_conf

    def decode_pred(self, result):
        """Decode prediction.

        Assemble boxes and predicted labels into bboxes, and convert edges
        into matrix.

        Args:
            result (dict): Model output with keys ``img_metas``, ``nodes``
                and ``edges``.

        Returns:
            dict: Keys ``filename``, ``boxes``, ``bboxes``, ``labels``,
            ``labels_conf``, ``texts``, ``pairs`` and ``pairs_conf``.
        """
        filename = result['img_metas'][0]['ori_filename']
        nodes = result['nodes'].cpu()
        labels_conf, labels = torch.max(nodes, dim=-1)
        num_nodes = nodes.size(0)
        # The last column of ``edges`` is used as the link score and is
        # reshaped into a (num_nodes, num_nodes) matrix.
        edges = result['edges'][:, -1].view(num_nodes, num_nodes).cpu()
        annos = self.data_dict[filename]['annotations']
        boxes = [x['box'] for x in annos]
        texts = [x['text'] for x in annos]
        # NOTE(review): columns 0, 1, 4, 5 are presumably the top-left and
        # bottom-right corners of an 8-value quadrilateral -- confirm.
        bboxes = torch.Tensor(boxes)[:, [0, 1, 4, 5]]
        bboxes = torch.cat([bboxes, labels[:, None].float()], -1)
        pairs, pairs_conf = self._decode_pairs_pred(
            nodes,
            labels,
            edges,
            self.edge_thr,
            self.link_type,
            key_node_idx=self.key_node_idx,
            value_node_idx=self.value_node_idx)
        pred = {
            'filename': filename,
            'boxes': boxes,
            'bboxes': bboxes.tolist(),
            'labels': labels.tolist(),
            'labels_conf': labels_conf.tolist(),
            'texts': texts,
            'pairs': pairs,
            'pairs_conf': pairs_conf
        }
        return pred

    def decode_gt(self, filename):
        """Decode ground truth.

        Assemble boxes and labels into bboxes.

        Args:
            filename (str): Key into ``self.data_dict``.

        Returns:
            dict: Same keys as :meth:`decode_pred`; all confidences are 1.
        """
        annos = self.data_dict[filename]['annotations']
        labels = torch.Tensor([x['label'] for x in annos])
        texts = [x['text'] for x in annos]
        edge_ids = [x['edge'] for x in annos]
        boxes = [x['box'] for x in annos]
        bboxes = torch.Tensor(boxes)[:, [0, 1, 4, 5]]
        bboxes = torch.cat([bboxes, labels[:, None].float()], -1)
        pairs = self._decode_pairs_gt(labels, edge_ids)
        gt = {
            'filename': filename,
            'boxes': boxes,
            'bboxes': bboxes.tolist(),
            'labels': labels.tolist(),
            'labels_conf': [1. for _ in labels],
            'texts': texts,
            'pairs': pairs,
            'pairs_conf': [1. for _ in pairs]
        }
        return gt

    @staticmethod
    def _f1_score(recall, precision):
        """Return the harmonic mean of recall and precision (0 if both are
        0)."""
        denom = recall + precision
        if denom <= 0:
            return 0.
        return 2 * recall * precision / denom

    def compute_openset_f1(self, preds, gts):
        """Compute openset macro-f1 and micro-f1 score.

        Args:
            preds: (list[dict]): List of prediction results, including
                keys: ``filename``, ``pairs``, etc.
            gts: (list[dict]): List of ground-truth infos, including keys:
                ``filename``, ``pairs``, etc.

        Returns:
            dict: Evaluation result with keys: ``node_openset_micro_f1``, \
                ``node_openset_macro_f1``, ``edge_openset_f1``.
        """
        total_edge_hit_num, total_edge_gt_num, total_edge_pred_num = 0, 0, 0
        node_inds = list(range(self.node_classes))
        total_node_hit_num = dict.fromkeys(node_inds, 0)
        total_node_gt_num = dict.fromkeys(node_inds, 0)
        total_node_pred_num = dict.fromkeys(node_inds, 0)

        for pred, gt in zip(preds, gts):
            # edge metric related
            pairs_pred = pred['pairs']
            pairs_gt = gt['pairs']
            total_edge_hit_num += sum(
                1 for pair in pairs_gt if pair in pairs_pred)
            total_edge_gt_num += len(pairs_gt)
            total_edge_pred_num += len(pairs_pred)

            # node metric related
            nodes_pred = pred['labels']
            nodes_gt = gt['labels']
            for i, node_gt in enumerate(nodes_gt):
                # Ground-truth labels arrive as floats (see ``decode_gt``).
                node_gt = int(node_gt)
                total_node_gt_num[node_gt] += 1
                if nodes_pred[i] == node_gt:
                    total_node_hit_num[node_gt] += 1
            for node_pred in nodes_pred:
                total_node_pred_num[node_pred] += 1

        # edge f1; ``max(1, count)`` only guards the division for empty sets
        total_edge_recall = 1.0 * total_edge_hit_num / max(
            1, total_edge_gt_num)
        total_edge_precision = 1.0 * total_edge_hit_num / max(
            1, total_edge_pred_num)
        edge_f1 = self._f1_score(total_edge_recall, total_edge_precision)
        stats = {'edge_openset_f1': edge_f1}

        # node f1
        cared_node_hit_num, cared_node_gt_num, cared_node_pred_num = 0, 0, 0
        node_macro_metric = {}
        for node_idx in node_inds:
            # NOTE(review): only classes 1 and 2 are scored; this matches
            # the default key/value indices but is hard-coded -- confirm
            # whether it should follow key_node_idx / value_node_idx.
            if node_idx < 1 or node_idx > 2:
                continue
            cared_node_hit_num += total_node_hit_num[node_idx]
            cared_node_gt_num += total_node_gt_num[node_idx]
            cared_node_pred_num += total_node_pred_num[node_idx]
            node_res = {}
            node_res['recall'] = 1.0 * total_node_hit_num[node_idx] / max(
                1, total_node_gt_num[node_idx])
            node_res['precision'] = 1.0 * total_node_hit_num[node_idx] / max(
                1, total_node_pred_num[node_idx])
            node_res['f1'] = self._f1_score(node_res['recall'],
                                            node_res['precision'])
            node_macro_metric[node_idx] = node_res

        node_micro_recall = 1.0 * cared_node_hit_num / max(
            1, cared_node_gt_num)
        node_micro_precision = 1.0 * cared_node_hit_num / max(
            1, cared_node_pred_num)
        node_micro_f1 = self._f1_score(node_micro_recall,
                                       node_micro_precision)

        stats['node_openset_micro_f1'] = node_micro_f1
        stats['node_openset_macro_f1'] = np.mean(
            [v['f1'] for k, v in node_macro_metric.items()])

        return stats
Read the Docs v: v0.4.0
Versions
latest
stable
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.