Shortcuts

Note

You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. See the maintenance plan, changelog, code, and documentation of MMOCR 1.0 for more details.

Source code for mmocr.models.ner.losses.masked_focal_loss

# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn

from mmocr.models.builder import LOSSES
from mmocr.models.common.losses.focal_loss import FocalLoss


@LOSSES.register_module()
class MaskedFocalLoss(nn.Module):
    """Focal loss restricted to the non-padding tokens of a sequence.

    When an attention mask is available, only positions whose mask value is
    1 (real tokens) contribute to the loss. Positions with mask 0 (padding)
    are dropped before the wrapped :class:`FocalLoss` is evaluated.

    Args:
        num_labels (int, optional): Number of classes in labels. If None,
            it is taken from the last dimension of ``logits`` at forward
            time. Defaults to None.
        ignore_index (int): Specifies a target value that is ignored and
            does not contribute to the input gradient. Forwarded to
            :class:`FocalLoss`. Defaults to 0.
    """

    def __init__(self, num_labels=None, ignore_index=0):
        super().__init__()
        self.num_labels = num_labels
        self.criterion = FocalLoss(ignore_index=ignore_index)

    def forward(self, logits, img_metas):
        """Compute the masked focal loss.

        Args:
            logits (Tensor): Model output, reshaped to ``[-1, num_labels]``
                before the loss is computed (documented upstream as shape
                ``[N, C]``).
            img_metas (dict): A dict containing the following keys:

                - labels (Tensor): The label of each token of the sequence.
                  It is flattened to 1-D, so it must hold one label per row
                  of the reshaped ``logits``.
                - attention_masks (Tensor, optional): The mask for each token
                  of the sequence. The mask has 1 for real tokens and 0 for
                  padding tokens; only real tokens contribute to the loss.
                  If the key is absent or its value is None, every position
                  is used.

                Other keys (e.g. ``input_ids``, ``token_type_ids``) are not
                read by this method.

        Returns:
            dict: ``{'loss_cls': loss}``, where ``loss`` is the value
            returned by the wrapped :class:`FocalLoss`.
        """
        # ``num_labels`` defaults to None; ``view(-1, None)`` would fail, so
        # infer the class count from the logits' last dimension instead.
        num_labels = self.num_labels
        if num_labels is None:
            num_labels = logits.size(-1)

        # ``reshape`` matches ``view`` on contiguous tensors but also accepts
        # non-contiguous ones instead of raising.
        flat_logits = logits.reshape(-1, num_labels)
        flat_labels = img_metas['labels'].reshape(-1)

        # ``.get`` lets a missing mask fall through to the unmasked path,
        # which the None-check below was clearly meant to support.
        attention_masks = img_metas.get('attention_masks')
        if attention_masks is not None:
            # Only keep active (non-padding) parts of the loss.
            active = attention_masks.reshape(-1) == 1
            flat_logits = flat_logits[active]
            flat_labels = flat_labels[active]

        loss = self.criterion(flat_logits, flat_labels)
        return {'loss_cls': loss}
Read the Docs v: v0.6.3
Versions
latest
stable
v0.6.3
v0.6.2
v0.6.1
v0.6.0
v0.5.0
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
dev-1.x
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.