
Source code for mmocr.apis.inferencers.kie_inferencer

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from typing import Any, Dict, List, Optional, Sequence, Union

import mmcv
import mmengine
import numpy as np
from mmengine.dataset import Compose, pseudo_collate
from mmengine.runner.checkpoint import _load_checkpoint

from mmocr.registry import DATASETS
from mmocr.structures import KIEDataSample
from mmocr.utils import ConfigType
from .base_mmocr_inferencer import BaseMMOCRInferencer, ModelType, PredType

InputType = Dict
InputsType = Sequence[Dict]


class KIEInferencer(BaseMMOCRInferencer):
    """Key Information Extraction Inferencer.

    Args:
        model (str, optional): Path to the config file or the model name
            defined in metafile. For example, it could be
            "sdmgr_unet16_60e_wildreceipt" or
            "configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py".
            If model is not specified, user must provide the
            `weights` saved by MMEngine which contains the config string.
            Defaults to None.
        weights (str, optional): Path to the checkpoint. If it is not
            specified and model is a model name of metafile, the weights
            will be loaded from metafile. Defaults to None.
        device (str, optional): Device to run inference. If None, the
            available device will be automatically used. Defaults to None.
        scope (str, optional): The scope of the model. Defaults to "mmocr".
    """

    def __init__(self,
                 model: Union[ModelType, str, None] = None,
                 weights: Optional[str] = None,
                 device: Optional[str] = None,
                 scope: Optional[str] = 'mmocr') -> None:
        super().__init__(
            model=model, weights=weights, device=device, scope=scope)
        self._load_metainfo_to_visualizer(weights, self.cfg)
        self.collate_fn = self.kie_collate

    def _load_metainfo_to_visualizer(self, weights: Optional[str],
                                     cfg: ConfigType) -> None:
        """Load meta information to visualizer."""
        if hasattr(self, 'visualizer'):
            if weights is not None:
                w = _load_checkpoint(weights, map_location='cpu')
                if w and 'meta' in w and 'dataset_meta' in w['meta']:
                    self.visualizer.dataset_meta = w['meta']['dataset_meta']
                    return
            if 'test_dataloader' in cfg:
                dataset_cfg = copy.deepcopy(cfg.test_dataloader.dataset)
                dataset_cfg['lazy_init'] = True
                dataset_cfg['metainfo'] = None
                dataset = DATASETS.build(dataset_cfg)
                self.visualizer.dataset_meta = dataset.metainfo
            else:
                raise ValueError(
                    'KIEVisualizer requires meta information from weights or '
                    'test dataset, but none of them is provided.')

    def _init_pipeline(self, cfg: ConfigType) -> Compose:
        """Initialize the test pipeline."""
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        idx = self._get_transform_idx(pipeline_cfg, 'LoadKIEAnnotations')
        if idx == -1:
            raise ValueError(
                'LoadKIEAnnotations is not found in the test pipeline')
        pipeline_cfg[idx]['with_label'] = False
        self.novisual = all(
            self._get_transform_idx(pipeline_cfg, t) == -1
            for t in self.loading_transforms)
        # Remove Resize from the test pipeline, since SDMGR requires bbox
        # annotations to be resized together with the images, while
        # visualization loads the original image from the disk.
        # TODO: find a more elegant way to fix this
        idx = self._get_transform_idx(pipeline_cfg, 'Resize')
        if idx != -1:
            pipeline_cfg.pop(idx)
        # In non-visual mode, self.pipeline is built here directly.
        # Otherwise, file_pipeline and ndarray_pipeline are built by the
        # base class.
        if self.novisual:
            return Compose(pipeline_cfg)
        return super()._init_pipeline(cfg)
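
    # Illustrative usage (a sketch, not part of the source): both
    # construction forms come from the class docstring above; the
    # checkpoint path is a placeholder.
    #
    #   >>> from mmocr.apis.inferencers import KIEInferencer
    #   >>> # From a metafile model name (weights resolved automatically):
    #   >>> kie = KIEInferencer(model='sdmgr_unet16_60e_wildreceipt')
    #   >>> # Or from an explicit config and checkpoint:
    #   >>> kie = KIEInferencer(
    #   ...     model='configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py',
    #   ...     weights='sdmgr.pth',  # placeholder path
    #   ...     device='cpu')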

    @staticmethod
    def kie_collate(data_batch: Sequence) -> Any:
        """A collate function designed for KIE, where the first element
        (input) of each sample is a dict and we want to keep it as-is
        instead of batching the elements inside it.

        Returns:
            Any: Transposed data in the same format as the data items of
            ``data_batch``.
        """
        transposed = list(zip(*data_batch))
        for i in range(1, len(transposed)):
            transposed[i] = pseudo_collate(transposed[i])
        return transposed
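
    # Sketch of the behavior (illustrative, not part of the source): given
    # a batch of (input_dict, data_sample) pairs, ``kie_collate`` returns
    #
    #   [ (input_dict_1, input_dict_2, ...),          # kept as-is
    #     pseudo_collate((data_sample_1, ...)) ]      # batched by mmengine
    #
    # so downstream code receives the raw input dicts untouched.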

    def _inputs_to_list(self, inputs: InputsType) -> list:
        """Preprocess the inputs to a list.

        Preprocess inputs to a list according to its type.

        The inputs can be a dict or list[dict], where each dictionary
        contains the following keys:

        - img (str or ndarray): Path to the image or the image itself. If
          KIE Inferencer is used in no-visual mode, this key is not
          required. Note: If it's a numpy array, it should be in BGR order.
        - img_shape (tuple(int, int)): Image shape in (H, W). Only required
          in no-visual mode when ``img`` is not provided.
        - instances (list[dict]): A list of instances, where each instance
          has the following keys:

          - bbox (ndarray(dtype=np.float32)): Shape (4, ). Bounding box.
          - text (str): Annotation text.

        Each ``instance`` looks like the following:

        .. code-block:: python

            {
                # An array of 4 numbers representing the bounding box of
                # the instance, in (x1, y1, x2, y2) order.
                'bbox': np.array([x1, y1, x2, y2], dtype=np.float32),

                # The annotation text.
                'text': 'text1',
            }

        Args:
            inputs (InputsType): Inputs for the inferencer.

        Returns:
            list: List of inputs for :meth:`preprocess`.
        """
        processed_inputs = []

        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        for single_input in inputs:
            if self.novisual:
                processed_input = copy.deepcopy(single_input)
                if 'img' not in single_input and \
                        'img_shape' not in single_input:
                    raise ValueError(
                        'KIEInferencer in no-visual mode requires the input '
                        'to have "img" or "img_shape", but neither is '
                        'found.')
                if 'img' in single_input:
                    img = single_input['img']
                    if isinstance(img, str):
                        img_bytes = mmengine.fileio.get(img)
                        img = mmcv.imfrombytes(img_bytes)
                    processed_input['img'] = img
                    processed_input['img_shape'] = img.shape[:2]
                processed_inputs.append(processed_input)
            else:
                if 'img' not in single_input:
                    raise ValueError(
                        'This inferencer is constructed to accept image '
                        'inputs, but the input does not contain an "img" '
                        'key.')
                if isinstance(single_input['img'], str):
                    processed_input = {
                        k: v
                        for k, v in single_input.items() if k != 'img'
                    }
                    processed_input['img_path'] = single_input['img']
                    processed_inputs.append(processed_input)
                elif isinstance(single_input['img'], np.ndarray):
                    processed_inputs.append(copy.deepcopy(single_input))
                else:
                    atype = type(single_input['img'])
                    raise ValueError(f'Unsupported input type: {atype}')

        return processed_inputs
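
    # Illustrative inputs (a sketch, not part of the source) in the format
    # described in the docstring above; the image path is a placeholder.
    #
    #   >>> import numpy as np
    #   >>> instances = [
    #   ...     dict(bbox=np.array([0, 0, 10, 10], dtype=np.float32),
    #   ...          text='hello')]
    #   >>> visual_input = dict(img='demo/demo_kie.jpeg',
    #   ...                     instances=instances)
    #   >>> novisual_input = dict(img_shape=(1024, 768),
    #   ...                       instances=instances)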

    def visualize(self,
                  inputs: InputsType,
                  preds: PredType,
                  return_vis: bool = False,
                  show: bool = False,
                  wait_time: float = 0,
                  draw_pred: bool = True,
                  pred_score_thr: float = 0.3,
                  save_vis: bool = False,
                  img_out_dir: str = '') -> Union[List[np.ndarray], None]:
        """Visualize predictions.

        Args:
            inputs (List[Dict]): Inputs for the inferencer.
            preds (List[Dict]): Predictions of the model.
            return_vis (bool): Whether to return the visualization result.
                Defaults to False.
            show (bool): Whether to display the image in a pop-up window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            save_vis (bool): Whether to save the visualization result.
                Defaults to False.
            img_out_dir (str): Output directory of visualization results.
                If left as empty, no file will be saved. Defaults to ''.

        Returns:
            List[np.ndarray] or None: Visualization results, or None if
            visualization is not requested.
        """
        if self.visualizer is None or not (show or save_vis or return_vis):
            return None

        results = []

        for single_input, pred in zip(inputs, preds):
            if 'img' in single_input:
                if isinstance(single_input['img'], str):
                    img_bytes = mmengine.fileio.get(single_input['img'])
                    img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
                elif isinstance(single_input['img'], np.ndarray):
                    img = single_input['img'].copy()[:, :, ::-1]  # To RGB
            elif 'img_shape' in single_input:
                img = np.zeros(single_input['img_shape'], dtype=np.uint8)
            else:
                raise ValueError('Input does not contain either "img" or '
                                 '"img_shape"')
            img_name = osp.splitext(osp.basename(pred.img_path))[0]

            if save_vis and img_out_dir:
                out_file = osp.join(img_out_dir, f'{img_name}.jpg')
            else:
                out_file = None

            visualization = self.visualizer.add_datasample(
                img_name,
                img,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=draw_pred,
                pred_score_thr=pred_score_thr,
                out_file=out_file,
            )
            results.append(visualization)

        return results
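
    # Illustrative call (a sketch, not part of the source), assuming the
    # ``__call__`` signature inherited from the base inferencer, which
    # forwards these keyword arguments to ``visualize``. The output
    # directory is a placeholder.
    #
    #   >>> kie(visual_input, save_vis=True, out_dir='kie_outputs/')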

    def pred2dict(self, data_sample: KIEDataSample) -> Dict:
        """Extract elements necessary to represent a prediction into a
        dictionary.

        It's better to contain only basic data elements such as strings and
        numbers in order to guarantee it's json-serializable.

        Args:
            data_sample (KIEDataSample): The data sample to be converted.

        Returns:
            dict: The output dictionary.
        """
        result = {}
        pred = data_sample.pred_instances
        result['scores'] = pred.scores.cpu().numpy().tolist()
        result['edge_scores'] = pred.edge_scores.cpu().numpy().tolist()
        result['edge_labels'] = pred.edge_labels.cpu().numpy().tolist()
        result['labels'] = pred.labels.cpu().numpy().tolist()
        return result
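
# Illustrative output shape (a sketch, not part of the source), assuming
# the result dict returned by the base inferencer's ``__call__``:
#
#   >>> result = kie(novisual_input)
#   >>> sorted(result['predictions'][0])
#   ['edge_labels', 'edge_scores', 'labels', 'scores']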