
Source code for mmocr.models.textrecog.convertors.attn

# Copyright (c) OpenMMLab. All rights reserved.
import torch

import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS
from .base import BaseConvertor


@CONVERTORS.register_module()
class AttnConvertor(BaseConvertor):
    """Convert between text, index and tensor for an encoder-decoder based
    pipeline.

    Args:
        dict_type (str): Type of dict, should be one of {'DICT36', 'DICT90'}.
        dict_file (None|str): Character dict file path. If not None, it has
            higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not None, it has
            higher priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add a `<UKN>` token to the dictionary.
        max_seq_len (int): Maximum sequence length of label.
        lower (bool): If True, convert the original string to lower case.
        start_end_same (bool): Whether to use the same index for the start
            and end tokens. Default: True.
    """

    def __init__(self,
                 dict_type='DICT90',
                 dict_file=None,
                 dict_list=None,
                 with_unknown=True,
                 max_seq_len=40,
                 lower=False,
                 start_end_same=True,
                 **kwargs):
        super().__init__(dict_type, dict_file, dict_list)
        assert isinstance(with_unknown, bool)
        assert isinstance(max_seq_len, int)
        assert isinstance(lower, bool)

        self.with_unknown = with_unknown
        self.max_seq_len = max_seq_len
        self.lower = lower
        self.start_end_same = start_end_same

        self.update_dict()

    def update_dict(self):
        start_end_token = '<BOS/EOS>'
        unknown_token = '<UKN>'
        padding_token = '<PAD>'

        # unknown
        self.unknown_idx = None
        if self.with_unknown:
            self.idx2char.append(unknown_token)
            self.unknown_idx = len(self.idx2char) - 1

        # BOS/EOS
        self.idx2char.append(start_end_token)
        self.start_idx = len(self.idx2char) - 1
        if not self.start_end_same:
            self.idx2char.append(start_end_token)
        self.end_idx = len(self.idx2char) - 1

        # padding
        self.idx2char.append(padding_token)
        self.padding_idx = len(self.idx2char) - 1

        # update char2idx
        self.char2idx = {}
        for idx, char in enumerate(self.idx2char):
            self.char2idx[char] = idx
    def str2tensor(self, strings):
        """Convert text strings into tensors.

        Args:
            strings (list[str]): For example, ['hello', 'world'].

        Returns:
            dict: A dict with two keys:

                - targets (list[Tensor]): Unpadded index tensors, e.g.
                  [torch.LongTensor([1, 2, 3, 3, 4]),
                   torch.LongTensor([5, 4, 6, 3, 7])].
                - padded_targets (Tensor): Shape (bsz, max_seq_len).
        """
        assert utils.is_type_list(strings, str)

        tensors, padded_targets = [], []
        indexes = self.str2idx(strings)
        for index in indexes:
            tensor = torch.LongTensor(index)
            tensors.append(tensor)
            # target tensor for loss: <BOS> + chars + <EOS>
            src_target = torch.LongTensor(tensor.size(0) + 2).fill_(0)
            src_target[-1] = self.end_idx
            src_target[0] = self.start_idx
            src_target[1:-1] = tensor
            padded_target = (torch.ones(self.max_seq_len) *
                             self.padding_idx).long()
            char_num = src_target.size(0)
            if char_num > self.max_seq_len:
                padded_target = src_target[:self.max_seq_len]
            else:
                padded_target[:char_num] = src_target
            padded_targets.append(padded_target)
        padded_targets = torch.stack(padded_targets, 0).long()

        return {'targets': tensors, 'padded_targets': padded_targets}
    def tensor2idx(self, outputs, img_metas=None):
        """Convert an output tensor to text indexes.

        Args:
            outputs (Tensor): Model outputs with size N * T * C.
            img_metas (list[dict]): Each dict contains info for one image.

        Returns:
            indexes (list[list[int]]): For example,
                [[1, 2, 3, 3, 4], [5, 4, 6, 3, 7]].
            scores (list[list[float]]): For example,
                [[0.9, 0.8, 0.95, 0.97, 0.94],
                 [0.9, 0.9, 0.98, 0.97, 0.96]].
        """
        batch_size = outputs.size(0)
        ignore_indexes = [self.padding_idx]
        indexes, scores = [], []
        for idx in range(batch_size):
            seq = outputs[idx, :, :]
            # Greedy decoding: take the best class at every time step.
            max_value, max_idx = torch.max(seq, -1)
            str_index, str_score = [], []
            output_index = max_idx.cpu().detach().numpy().tolist()
            output_score = max_value.cpu().detach().numpy().tolist()
            for char_index, char_score in zip(output_index, output_score):
                if char_index in ignore_indexes:
                    continue
                if char_index == self.end_idx:
                    break
                str_index.append(char_index)
                str_score.append(char_score)
            indexes.append(str_index)
            scores.append(str_score)

        return indexes, scores
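
The token layout produced by update_dict follows directly from the append order above: dictionary characters first, then <UKN> (if enabled), then <BOS/EOS>, then <PAD>. A minimal sketch of the resulting indices with the default settings, assuming AttnConvertor is importable from mmocr.models.textrecog.convertors as in MMOCR 0.x and that the stock DICT90 table holds exactly 90 characters:

from mmocr.models.textrecog.convertors import AttnConvertor

convertor = AttnConvertor(dict_type='DICT90', with_unknown=True,
                          start_end_same=True, max_seq_len=40)

# 90 characters occupy indices 0-89; special tokens are appended after.
assert convertor.unknown_idx == 90
assert convertor.start_idx == convertor.end_idx == 91  # start_end_same=True
assert convertor.padding_idx == 92
assert len(convertor.idx2char) == 93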
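
For training targets, str2tensor wraps each label with start_idx and end_idx and right-pads it with padding_idx up to max_seq_len. A hedged sketch of the resulting shapes, reusing the convertor instance above (the exact character indexes depend on the dictionary in use):

out = convertor.str2tensor(['hello', 'world'])

# One unpadded LongTensor of character indexes per input string.
assert len(out['targets']) == 2
assert out['targets'][0].size(0) == 5

# Padded targets: <BOS>, 5 chars, <EOS>, then <PAD> up to max_seq_len.
padded = out['padded_targets']
assert padded.shape == (2, 40)
assert padded[0, 0].item() == convertor.start_idx
assert padded[0, 6].item() == convertor.end_idx
assert padded[0, 7].item() == convertor.padding_idx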
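
At inference time, tensor2idx greedily takes the argmax class per step, skips padding predictions, and truncates at the first end_idx. A minimal sketch with random scores standing in for real decoder probabilities; idx2str here is the index-to-string helper inherited from BaseConvertor:

import torch

# Stand-in for decoder output: batch of 2, 40 steps, 93 classes.
logits = torch.randn(2, 40, len(convertor.idx2char))

indexes, scores = convertor.tensor2idx(logits)
texts = convertor.idx2str(indexes)  # map index lists back to strings

# Each index gets a matching per-character score, capped at max_seq_len.
assert all(len(i) == len(s) for i, s in zip(indexes, scores))
assert all(len(i) <= convertor.max_seq_len for i in indexes)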