Shortcuts

Note

You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. See the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.

Source code for mmocr.models.textrecog.losses.ce_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn

from mmocr.models.builder import LOSSES


@LOSSES.register_module()
class CELoss(nn.Module):
    """Cross-entropy loss for encoder-decoder based text recognizers.

    Args:
        ignore_index (int): Target value that is skipped and contributes
            nothing to the input gradient.
        reduction (str): Reduction applied to the output; one of
            ``'none'``, ``'mean'`` or ``'sum'``.
        ignore_first_char (bool): If ``True``, drop the first token of every
            target sequence (usually the start token) and drop the last step
            of the output sequence so that both lengths line up.
    """

    def __init__(self,
                 ignore_index=-1,
                 reduction='none',
                 ignore_first_char=False):
        super().__init__()
        # Constructor arguments are validated with asserts; callers rely on
        # AssertionError being raised for bad values.
        assert isinstance(ignore_index, int)
        assert isinstance(reduction, str) and reduction in ('none', 'mean',
                                                            'sum')
        assert isinstance(ignore_first_char, bool)

        self.loss_ce = nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction=reduction)
        self.ignore_first_char = ignore_first_char

    def format(self, outputs, targets_dict):
        """Lay out logits and targets the way ``nn.CrossEntropyLoss`` expects.

        Args:
            outputs (Tensor): Raw logits, expected as :math:`(N, T, C)`.
            targets_dict (dict): Must contain the key ``padded_targets``.

        Returns:
            tuple(Tensor, Tensor): Logits permuted from :math:`(N, T, C)` to
            :math:`(N, C, T)`, and the targets, with the first column removed
            when ``ignore_first_char`` is set.
        """
        padded = targets_dict['padded_targets']
        if not self.ignore_first_char:
            return outputs.permute(0, 2, 1).contiguous(), padded

        # Skip the leading target token and the trailing output step so the
        # two sequences have the same length.
        shifted_targets = padded[:, 1:].contiguous()
        trimmed_logits = outputs[:, :-1, :]
        return trimmed_logits.permute(0, 2, 1).contiguous(), shifted_targets

    def forward(self, outputs, targets_dict, img_metas=None):
        """Compute the cross-entropy loss.

        Args:
            outputs (Tensor): Raw logits of shape :math:`(N, T, C)`.
            targets_dict (dict): Holds ``padded_targets``, a tensor of shape
                :math:`(N, T)` whose entries are character indices.
            img_metas (None): Unused.

        Returns:
            dict: A loss dict with the single key ``loss_ce``.
        """
        logits, targets = self.format(outputs, targets_dict)
        # Targets may be on a different device than the logits.
        targets = targets.to(logits.device)
        return dict(loss_ce=self.loss_ce(logits, targets))
@LOSSES.register_module()
class SARLoss(CELoss):
    """Loss module of `SAR <https://arxiv.org/abs/1811.00751>`_.

    Args:
        ignore_index (int): Target value that is skipped and contributes
            nothing to the input gradient.
        reduction (str): Reduction applied to the output; one of
            ``'none'``, ``'mean'`` or ``'sum'``.

    Warning:
        SARLoss assumes that the first input token is always `<SOS>`.
    """

    def __init__(self, ignore_index=-1, reduction='mean', **kwargs):
        # Extra keyword arguments are accepted but not forwarded or used.
        super().__init__(ignore_index, reduction)

    def format(self, outputs, targets_dict):
        """Drop the start token from targets and the final step from logits.

        Per-sample layout::

            targets: [start_idx, idx1, idx2, ..., end_idx, pad_idx, ...]
            outputs: [idx1, idx2, ..., end_idx, ...]

        Returns:
            tuple(Tensor, Tensor): Logits permuted so the class axis is
            second, and the shifted targets.
        """
        # Target position 0 holds the start token, which the outputs do not
        # line up with, so it is excluded from the loss.
        shifted_targets = targets_dict['padded_targets'][:, 1:].contiguous()
        # Remove the last output step so its length matches the targets.
        logits = outputs[:, :-1, :].permute(0, 2, 1).contiguous()
        return logits, shifted_targets
@LOSSES.register_module()
class TFLoss(CELoss):
    """Loss module for transformer-based text recognizers.

    Args:
        ignore_index (int, optional): Character index excluded from the
            loss computation.
        reduction (str): Reduction applied to the output; one of
            ``'none'``, ``'mean'`` or ``'sum'``.
        flatten (bool): If ``True``, merge all leading dimensions before
            computing the loss.

    Warning:
        TFLoss assumes that the first input token is always `<SOS>`.
    """

    def __init__(self,
                 ignore_index=-1,
                 reduction='none',
                 flatten=True,
                 **kwargs):
        # Extra keyword arguments are accepted but not forwarded or used.
        super().__init__(ignore_index, reduction)
        assert isinstance(flatten, bool)
        self.flatten = flatten

    def format(self, outputs, targets_dict):
        """Shift logits/targets by one step and reshape them for the loss.

        Assumes ``outputs`` is laid out as :math:`(N, T, C)` (TODO confirm
        for all callers).

        Returns:
            tuple(Tensor, Tensor): With ``flatten`` set, logits are reshaped
            to ``(-1, C)`` and targets to ``(-1,)``. Otherwise logits are
            permuted so the class axis is second and targets keep their
            2-D layout.
        """
        # Drop the start token from targets and the last output step.
        shifted_targets = targets_dict['padded_targets'][:, 1:].contiguous()
        logits = outputs[:, :-1, :].contiguous()

        if not self.flatten:
            return logits.permute(0, 2, 1).contiguous(), shifted_targets

        num_classes = logits.size(-1)
        return logits.view(-1, num_classes), shifted_targets.view(-1)
Read the Docs v: v0.6.3
Versions
latest
stable
v0.6.3
v0.6.2
v0.6.1
v0.6.0
v0.5.0
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
dev-1.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.