Shortcuts

Source code for mmocr.models.textrecog.losses.mix_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmocr.models.builder import LOSSES


@LOSSES.register_module()
class ABILoss(nn.Module):
    """Implementation of ABINet multiloss that allows mixing different types
    of losses with weights.

    Args:
        enc_weight (float): The weight of encoder loss. Defaults to 1.0.
        dec_weight (float): The weight of decoder loss. Defaults to 1.0.
        fusion_weight (float): The weight of fuser (aligner) loss.
            Defaults to 1.0.
        num_classes (int): Number of unique output language tokens.

    Returns:
        A dictionary whose key/value pairs are the losses of three modules.
    """

    def __init__(self,
                 enc_weight=1.0,
                 dec_weight=1.0,
                 fusion_weight=1.0,
                 num_classes=37,
                 **kwargs):
        # Tuple-form isinstance replaces the original chained
        # `isinstance(x, float) or isinstance(x, int)` checks.
        assert isinstance(enc_weight, (float, int))
        assert isinstance(dec_weight, (float, int))
        assert isinstance(fusion_weight, (float, int))
        super().__init__()
        self.enc_weight = enc_weight
        self.dec_weight = dec_weight
        self.fusion_weight = fusion_weight
        self.num_classes = num_classes

    def _flatten(self, logits, target_lens):
        """Keep only the first ``target_lens[i]`` steps of each sequence in
        ``logits`` and concatenate them into one (sum(target_lens), C)
        tensor, discarding padded positions."""
        return torch.cat(
            [s[:target_lens[i]] for i, s in enumerate(logits)])

    def _ce_loss(self, logits, targets):
        """Mean cross-entropy between ``logits`` (N, C) and integer class
        indices ``targets`` (N,), computed via one-hot + log-softmax."""
        targets_one_hot = F.one_hot(targets, self.num_classes)
        log_prob = F.log_softmax(logits, dim=-1)
        loss = -(targets_one_hot.to(log_prob.device) * log_prob).sum(dim=-1)
        return loss.mean()

    def _loss_over_iters(self, outputs, targets):
        """Average the cross-entropy loss over several iterative refinements.

        Args:
            outputs (list[Tensor]): One tensor per refinement iteration,
                each of shape (N, C) where N is the total number of
                flattened target steps and C the number of classes.
            targets (Tensor): Flattened target indices of shape (N,).
        """
        iter_num = len(outputs)
        dec_outputs = torch.cat(outputs, dim=0)
        # Tile the targets once per iteration so every refinement step
        # contributes equally to the mean loss.
        flatten_targets_iternum = targets.repeat(iter_num)
        return self._ce_loss(dec_outputs, flatten_targets_iternum)

    def forward(self, outputs, targets_dict, img_metas=None):
        """
        Args:
            outputs (dict): The output dictionary with at least one of
                ``out_enc``, ``out_decs`` and ``out_fusers`` specified.
            targets_dict (dict): The target dictionary containing the key
                ``targets``, a list of per-sample target index tensors.
            img_metas (list[dict], optional): Unused; kept for API
                compatibility with other losses.

        Returns:
            A loss dictionary with ``loss_visual``, ``loss_lang`` and
            ``loss_fusion``. A key is present only when the output of its
            corresponding module is given.
        """
        # BUGFIX: the original asserted on the key 'out_dec', which is never
        # read below -- the decoder branch consumes 'out_decs'. Accept the
        # keys that are actually used (plus the legacy 'out_dec' spelling so
        # existing callers do not start failing the assert).
        assert 'out_enc' in outputs or 'out_decs' in outputs \
            or 'out_dec' in outputs or 'out_fusers' in outputs
        losses = {}

        target_lens = [len(t) for t in targets_dict['targets']]
        flatten_targets = torch.cat(targets_dict['targets'])

        # Truthiness (not `is not None`) is deliberate: a missing or empty
        # entry skips the corresponding branch, matching the original.
        if outputs.get('out_enc', None):
            enc_input = self._flatten(outputs['out_enc']['logits'],
                                      target_lens)
            enc_loss = self._ce_loss(enc_input,
                                     flatten_targets) * self.enc_weight
            losses['loss_visual'] = enc_loss
        if outputs.get('out_decs', None):
            dec_logits = [
                self._flatten(o['logits'], target_lens)
                for o in outputs['out_decs']
            ]
            dec_loss = self._loss_over_iters(
                dec_logits, flatten_targets) * self.dec_weight
            losses['loss_lang'] = dec_loss
        if outputs.get('out_fusers', None):
            fusion_logits = [
                self._flatten(o['logits'], target_lens)
                for o in outputs['out_fusers']
            ]
            fusion_loss = self._loss_over_iters(
                fusion_logits, flatten_targets) * self.fusion_weight
            losses['loss_fusion'] = fusion_loss
        return losses
Read the Docs v: v0.4.0
Versions
latest
stable
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.