Source code for mmocr.models.textrecog.losses.mix_loss
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmocr.models.builder import LOSSES


@LOSSES.register_module()
class ABILoss(nn.Module):
    """Implementation of the ABINet multiloss, which mixes the losses of
    different modules with configurable weights.

    Args:
        enc_weight (float): The weight of the encoder loss. Defaults to 1.0.
        dec_weight (float): The weight of the decoder loss. Defaults to 1.0.
        fusion_weight (float): The weight of the fuser (aligner) loss.
            Defaults to 1.0.
        num_classes (int): Number of unique output language tokens.

    Returns:
        A dictionary whose key/value pairs are the losses of the three
        modules.
    """

    def __init__(self,
                 enc_weight=1.0,
                 dec_weight=1.0,
                 fusion_weight=1.0,
                 num_classes=37,
                 **kwargs):
        assert isinstance(enc_weight, (int, float))
        assert isinstance(dec_weight, (int, float))
        assert isinstance(fusion_weight, (int, float))
        super().__init__()
        self.enc_weight = enc_weight
        self.dec_weight = dec_weight
        self.fusion_weight = fusion_weight
        self.num_classes = num_classes

    def _flatten(self, logits, target_lens):
        # Keep only the valid (unpadded) steps of each sample and concatenate
        # them into a single (sum(target_lens), C) tensor.
        flatten_logits = torch.cat(
            [s[:target_lens[i]] for i, s in enumerate(logits)])
        return flatten_logits

    def _ce_loss(self, logits, targets):
        targets_one_hot = F.one_hot(targets, self.num_classes)
        log_prob = F.log_softmax(logits, dim=-1)
        loss = -(targets_one_hot.to(log_prob.device) * log_prob).sum(dim=-1)
        return loss.mean()

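    # Note: for integer ``targets`` this is numerically equivalent to
    # F.cross_entropy(logits, targets) with the default 'mean' reduction.
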
    def _loss_over_iters(self, outputs, targets):
        """
        Args:
            outputs (list[Tensor]): Flattened logits from each refinement
                iteration, each of shape (sum(target_lens), C) where C is
                the number of classes.
            targets (Tensor): Flattened target sequence of shape
                (sum(target_lens), ).
        """
        iter_num = len(outputs)
        dec_outputs = torch.cat(outputs, dim=0)
        # Repeat the targets so that every iteration is supervised by the
        # same ground truth.
        flatten_targets_iternum = targets.repeat(iter_num)
        return self._ce_loss(dec_outputs, flatten_targets_iternum)

    def forward(self, outputs, targets_dict, img_metas=None):
        """
        Args:
            outputs (dict): The output dictionary with at least one of
                ``out_enc``, ``out_decs`` and ``out_fusers`` specified.
            targets_dict (dict): The target dictionary containing the key
                ``targets``, a list of per-sample target sequence tensors.

        Returns:
            A loss dictionary with up to three keys, ``loss_visual``,
            ``loss_lang`` and ``loss_fusion``. A key is present only if the
            output of its corresponding module is given.
        """
        assert 'out_enc' in outputs or \
            'out_decs' in outputs or 'out_fusers' in outputs
        losses = {}

        target_lens = [len(t) for t in targets_dict['targets']]
        flatten_targets = torch.cat(targets_dict['targets'])

        if outputs.get('out_enc', None):
            enc_input = self._flatten(outputs['out_enc']['logits'],
                                      target_lens)
            enc_loss = self._ce_loss(enc_input,
                                     flatten_targets) * self.enc_weight
            losses['loss_visual'] = enc_loss
        if outputs.get('out_decs', None):
            dec_logits = [
                self._flatten(o['logits'], target_lens)
                for o in outputs['out_decs']
            ]
            dec_loss = self._loss_over_iters(
                dec_logits, flatten_targets) * self.dec_weight
            losses['loss_lang'] = dec_loss
        if outputs.get('out_fusers', None):
            fusion_logits = [
                self._flatten(o['logits'], target_lens)
                for o in outputs['out_fusers']
            ]
            fusion_loss = self._loss_over_iters(
                fusion_logits, flatten_targets) * self.fusion_weight
            losses['loss_fusion'] = fusion_loss

        return losses
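
A minimal usage sketch, not part of the original module: the dictionary keys and tensor shapes follow the docstrings above, and the random tensors stand in for real model outputs.

import torch

from mmocr.models.textrecog.losses.mix_loss import ABILoss

# Two samples whose ground-truth sequences have lengths 5 and 3; logits use
# sequence length T=8 and C=37 classes, with three refinement iterations.
criterion = ABILoss(
    enc_weight=1.0, dec_weight=1.0, fusion_weight=1.0, num_classes=37)
targets_dict = dict(targets=[
    torch.randint(0, 37, (5, )),
    torch.randint(0, 37, (3, )),
])
outputs = dict(
    out_enc=dict(logits=torch.randn(2, 8, 37)),
    out_decs=[dict(logits=torch.randn(2, 8, 37)) for _ in range(3)],
    out_fusers=[dict(logits=torch.randn(2, 8, 37)) for _ in range(3)],
)
losses = criterion(outputs, targets_dict)
# -> {'loss_visual': tensor(...), 'loss_lang': tensor(...),
#     'loss_fusion': tensor(...)}

Because the class is registered via ``@LOSSES.register_module()``, configs can also select it declaratively with ``loss=dict(type='ABILoss', ...)`` instead of importing it directly.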