Note
You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend you upgrade to MMOCR 1.0 to enjoy fruitful new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.
Source code for mmocr.models.textrecog.convertors.seg
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np
import torch
import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS
from .base import BaseConvertor
@CONVERTORS.register_module()
class SegConvertor(BaseConvertor):
    """Convert between text, index and tensor for segmentation based pipeline.

    Args:
        dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
        dict_file (None|str): Character dict file path. If not none, the
            file is of higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, the list
            is of higher priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add `UKN` token to class.
        lower (bool): If True, convert original string to lower case.
    """

    def __init__(self,
                 dict_type='DICT36',
                 dict_file=None,
                 dict_list=None,
                 with_unknown=True,
                 lower=False,
                 **kwargs):
        super().__init__(dict_type, dict_file, dict_list)
        assert isinstance(with_unknown, bool)
        assert isinstance(lower, bool)

        self.with_unknown = with_unknown
        self.lower = lower
        self.update_dict()

    def update_dict(self):
        """Extend the inherited character dict with the segmentation-specific
        tokens and rebuild the char -> index mapping.

        Index 0 is reserved for the background class of the segmentation
        map; the optional ``<UKN>`` token is appended at the end and its
        index is recorded in ``self.unknown_idx``.
        """
        # background class occupies channel/index 0
        self.idx2char.insert(0, '<BG>')

        # unknown token (only when requested)
        self.unknown_idx = None
        if self.with_unknown:
            self.idx2char.append('<UKN>')
            self.unknown_idx = len(self.idx2char) - 1

        # rebuild the reverse mapping to stay consistent with idx2char
        self.char2idx = {char: idx for idx, char in enumerate(self.idx2char)}

    def tensor2str(self, output, img_metas=None):
        """Convert model output tensor to string labels.

        Each per-pixel class map is thresholded against the background,
        split into connected components, and every component is assigned
        its majority (non-background) class. Components are read left to
        right by centroid x-coordinate; tiny components (fewer than 20
        pixels of the majority class) are discarded as noise.

        Args:
            output (tensor): Model outputs with size: N * C * H * W.
            img_metas (list[dict]): Each dict contains one image info;
                must provide the key ``'valid_ratio'``.

        Returns:
            texts (list[str]): Decoded text labels.
            scores (list[list[float]]): Decoded chars scores.
        """
        assert utils.is_type_list(img_metas, dict)
        assert len(img_metas) == output.size(0)

        texts, scores = [], []
        for b in range(output.size(0)):
            seg_pred = output[b].detach()
            # restrict decoding to the un-padded part of the image
            valid_width = int(
                output.size(-1) * img_metas[b]['valid_ratio'] + 1)
            seg_res = torch.argmax(
                seg_pred[:, :, :valid_width],
                dim=0).cpu().numpy().astype(np.int32)

            # binarize: background (class 0) vs. any character class
            seg_thr = np.where(seg_res == 0, 0, 255).astype(np.uint8)
            _, labels, stats, centroids = cv2.connectedComponentsWithStats(
                seg_thr)

            component_num = stats.shape[0]

            all_res = []
            for i in range(component_num):
                temp_loc = (labels == i)
                temp_value = seg_res[temp_loc]
                temp_center = centroids[i]

                # majority vote over the per-pixel classes of this component
                temp_max_num = 0
                temp_max_cls = -1
                temp_total_num = 0
                for c in range(len(self.idx2char)):
                    c_num = np.sum(temp_value == c)
                    temp_total_num += c_num
                    if c_num > temp_max_num:
                        temp_max_num = c_num
                        temp_max_cls = c

                # skip components whose majority class is background
                # (component 0 is always the background region)
                if temp_max_cls == 0:
                    continue
                temp_max_score = 1.0 * temp_max_num / temp_total_num
                all_res.append(
                    [temp_max_cls, temp_center, temp_max_num, temp_max_score])

            # read characters left to right by centroid x-coordinate
            all_res = sorted(all_res, key=lambda s: s[1][0])

            chars, char_scores = [], []
            for res in all_res:
                temp_area = res[2]
                # filter out tiny noise components
                if temp_area < 20:
                    continue
                temp_char_index = res[0]
                # out-of-range, background and unknown indices decode to ''
                if temp_char_index >= len(self.idx2char):
                    temp_char = ''
                elif temp_char_index <= 0:
                    temp_char = ''
                elif temp_char_index == self.unknown_idx:
                    temp_char = ''
                else:
                    temp_char = self.idx2char[temp_char_index]
                chars.append(temp_char)
                char_scores.append(res[3])

            text = ''.join(chars)

            texts.append(text)
            scores.append(char_scores)

        return texts, scores