Shortcuts

Source code for mmocr.models.textdet.detectors.mmdet_wrapper

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union

import cv2
import torch
from mmdet.structures import DetDataSample
from mmdet.structures import SampleList as MMDET_SampleList
from mmdet.structures.mask import bitmap_to_polygon
from mmengine.model import BaseModel
from mmengine.structures import InstanceData

from mmocr.registry import MODELS
from mmocr.utils.bbox_utils import bbox2poly
from mmocr.utils.typing_utils import DetSampleList

# Every shape ``MMDetWrapper.forward`` may hand back, depending on its
# ``mode`` argument ("loss" -> dict, "predict" -> list of samples,
# "tensor" -> tensor or tuple of tensors).
ForwardResults = Union[
    Dict[str, torch.Tensor],
    List[DetDataSample],
    Tuple[torch.Tensor],
    torch.Tensor,
]


@MODELS.register_module()
class MMDetWrapper(BaseModel):
    """A wrapper of MMDet's model.

    Args:
        cfg (dict): The config of the model. It must contain a
            ``data_preprocessor`` entry. The dict passed in is not modified.
        text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
            Defaults to 'poly'.
    """

    def __init__(self, cfg: Dict, text_repr_type: str = 'poly') -> None:
        # Work on shallow copies so the caller's config is left untouched.
        # Popping ``data_preprocessor`` from the original dict would make a
        # second build from the same config fail with a KeyError.
        cfg = cfg.copy()
        data_preprocessor = cfg.pop('data_preprocessor').copy()
        data_preprocessor.update(_scope_='mmdet')
        super().__init__(data_preprocessor=data_preprocessor, init_cfg=None)
        cfg['_scope_'] = 'mmdet'
        self.wrapped_model = MODELS.build(cfg)
        self.text_repr_type = text_repr_type

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[Union[DetSampleList,
                                             MMDET_SampleList]] = None,
                mode: str = 'tensor',
                **kwargs) -> ForwardResults:
        """The unified entry for a forward process in both training and test.

        The method works in three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`DetDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method doesn't handle either back propagation or
        parameter update, which are supposed to be done in :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape (N, C, ...) in
                general.
            data_samples (list[:obj:`DetDataSample`] or
                list[:obj:`TextDetDataSample`]): The annotation data of every
                sample. When in "predict" mode, it should be a list of
                :obj:`TextDetDataSample`. Otherwise they are
                :obj:`DetDataSample`s. Defaults to None.
            mode (str): Running mode. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of :obj:`TextDetDataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            ValueError: If ``mode="predict"`` and ``data_samples`` is None.
        """
        if mode == 'predict':
            if data_samples is None:
                raise ValueError(
                    '"data_samples" must be provided in "predict" mode.')
            ocr_data_samples = data_samples
            # The wrapped MMDet model receives fresh DetDataSamples that carry
            # only the metainfo of the MMOCR samples.
            data_samples = [
                DetDataSample(metainfo=sample.metainfo)
                for sample in ocr_data_samples
            ]
        results = self.wrapped_model.forward(inputs, data_samples, mode,
                                             **kwargs)
        if mode == 'predict':
            results = self.adapt_predictions(results, ocr_data_samples)

        return results

    def adapt_predictions(self, data: MMDET_SampleList,
                          data_samples: DetSampleList) -> DetSampleList:
        """Convert instance data from MMDet into MMOCR's format.

        Args:
            data (list[DetDataSample]): Detection results of the input images.
                Each DetDataSample usually contains 'pred_instances', and the
                ``pred_instances`` usually contains the following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instances, ).
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4), the last
                  dimension 4 arranged as (x1, y1, x2, y2).
                - masks (Tensor, Optional): Has a shape (num_instances, H, W).
            data_samples (list[:obj:`TextDetDataSample`]): The annotation data
                of every sample. Modified in place and returned.

        Returns:
            list[TextDetDataSample]: A list of N datasamples containing ground
            truth and prediction results. The polygon results are saved in
            ``TextDetDataSample.pred_instances.polygons``. The confidence
            scores are saved in ``TextDetDataSample.pred_instances.scores``.
        """
        for i, det_data_sample in enumerate(data):
            pred_instances = det_data_sample.pred_instances
            data_samples[i].pred_instances = InstanceData()
            # convert mask to polygons if mask exists
            if 'masks' in pred_instances.keys():
                masks = pred_instances.masks.cpu().numpy()
                # Transfer all scores to the host once, not once per mask.
                mask_scores = pred_instances.scores.cpu().tolist()
                polygons = []
                scores = []
                for mask, score in zip(masks, mask_scores):
                    # One mask may yield several contours; each becomes its
                    # own polygon inheriting the mask's score.
                    contours, _ = bitmap_to_polygon(mask)
                    for contour in contours:
                        polygon = contour.reshape(-1)
                        # A valid polygon needs at least 3 points (6 coords).
                        if len(polygon) < 6:
                            continue
                        if self.text_repr_type == 'quad':
                            # cv2.minAreaRect expects an (N, 2) point array,
                            # not the flattened coordinate vector.
                            rect = cv2.minAreaRect(polygon.reshape(-1, 2))
                            polygon = cv2.boxPoints(rect).flatten()
                        polygons.append(polygon)
                        scores.append(score)
                data_samples[i].pred_instances.polygons = polygons
                data_samples[i].pred_instances.scores = torch.tensor(
                    scores, dtype=torch.float32)
            else:
                bboxes = pred_instances.bboxes.cpu().numpy()
                data_samples[i].pred_instances.polygons = [
                    bbox2poly(bbox) for bbox in bboxes
                ]
                data_samples[i].pred_instances.scores = (
                    pred_instances.scores.cpu().float())

        return data_samples
Read the Docs v: dev-1.x
Versions
latest
stable
v1.0.1
v1.0.0
0.x
v0.6.3
v0.6.2
v0.6.1
v0.6.0
v0.5.0
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.