Shortcuts

Note

You are reading the documentation for MMOCR 0.x, which will soon be deprecated by the end of 2022. We recommend you upgrade to MMOCR 1.0 to enjoy fruitful new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.

Source code for mmocr.models.ner.losses.masked_cross_entropy_loss

# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn
from torch.nn import CrossEntropyLoss

from mmocr.models.builder import LOSSES


@LOSSES.register_module()
class MaskedCrossEntropyLoss(nn.Module):
    """The implementation of masked cross entropy loss.

    The mask has 1 for real tokens and 0 for padding tokens, which only
    keeps active parts of the cross entropy loss.

    Args:
        num_labels (int): Number of classes in labels.
        ignore_index (int): Specifies a target value that is ignored
            and does not contribute to the input gradient.
    """

    def __init__(self, num_labels=None, ignore_index=0):
        super().__init__()
        self.num_labels = num_labels
        # ignore_index makes CrossEntropyLoss skip targets equal to that
        # value (padding) when averaging the loss.
        self.criterion = CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, logits, img_metas):
        """Loss forward.

        Args:
            logits: Model output with shape [N, C].
            img_metas (dict): A dict containing the following keys:

                - img (list): This parameter is reserved.
                - labels (list[int]): The labels for each word of the
                  sequence.
                - texts (list): The words of the sequence.
                - input_ids (list): The ids for each word of the sequence.
                - attention_masks (list): The mask for each word of the
                  sequence. The mask has 1 for real tokens and 0 for
                  padding tokens. Only real tokens are attended to.
                - token_type_ids (list): The tokens for each word of the
                  sequence.

        Returns:
            dict: A dict with the single key ``loss_cls`` holding the
            scalar cross entropy loss.
        """
        labels = img_metas['labels']
        attention_masks = img_metas['attention_masks']

        # Only keep active (non-padding) parts of the loss.
        if attention_masks is not None:
            active_loss = attention_masks.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)[active_loss]
            active_labels = labels.view(-1)[active_loss]
            loss = self.criterion(active_logits, active_labels)
        else:
            loss = self.criterion(
                logits.view(-1, self.num_labels), labels.view(-1))
        return {'loss_cls': loss}
Read the Docs v: v0.6.3
Versions
latest
stable
v0.6.3
v0.6.2
v0.6.1
v0.6.0
v0.5.0
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
dev-1.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.