Note
You are reading the documentation for MMOCR 0.x, which will soon be deprecated by the end of 2022. We recommend you upgrade to MMOCR 1.0 to enjoy fruitful new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.
Source code for mmocr.models.ner.losses.masked_focal_loss
# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn

from mmocr.models.builder import LOSSES
from mmocr.models.common.losses.focal_loss import FocalLoss


@LOSSES.register_module()
class MaskedFocalLoss(nn.Module):
"""The implementation of masked focal loss.
The mask has 1 for real tokens and 0 for padding tokens,
which only keep active parts of the focal loss
Args:
num_labels (int): Number of classes in labels.
ignore_index (int): Specifies a target value that is ignored
and does not contribute to the input gradient.
"""

    def __init__(self, num_labels=None, ignore_index=0):
        super().__init__()
        self.num_labels = num_labels
        self.criterion = FocalLoss(ignore_index=ignore_index)

    def forward(self, logits, img_metas):
        """Loss forward.

        Args:
            logits: Model output with shape [N, C].
            img_metas (dict): A dict containing the following keys:

                - img (list): This parameter is reserved.
                - labels (list[int]): The labels for each word
                  of the sequence.
                - texts (list): The words of the sequence.
                - input_ids (list): The ids for each word of
                  the sequence.
                - attention_masks (list): The mask for each word
                  of the sequence. The mask has 1 for real tokens
                  and 0 for padding tokens. Only real tokens are
                  attended to.
                - token_type_ids (list): The tokens for each word
                  of the sequence.
        """
        labels = img_metas['labels']
        attention_masks = img_metas['attention_masks']

        # Only keep the active (non-padding) parts of the loss
        if attention_masks is not None:
            active_loss = attention_masks.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)[active_loss]
            active_labels = labels.view(-1)[active_loss]
            loss = self.criterion(active_logits, active_labels)
        else:
            loss = self.criterion(
                logits.view(-1, self.num_labels), labels.view(-1))
        return {'loss_cls': loss}
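
For reference, a minimal usage sketch is given below. It builds the loss through the LOSSES registry shown above and feeds it a toy batch; the label count, tensor shapes, and concrete values are illustrative assumptions, not part of the original module.

import torch

from mmocr.models.builder import LOSSES

# Build the loss from a config dict, as it would appear in a model config.
loss_cfg = dict(type='MaskedFocalLoss', num_labels=5)
criterion = LOSSES.build(loss_cfg)

# Toy batch: 2 sequences of length 4, flattened to [N, C] logits.
logits = torch.randn(2 * 4, 5)
img_metas = dict(
    # Labels start from 1 here because ignore_index defaults to 0.
    labels=torch.randint(1, 5, (2, 4)),
    # 1 marks real tokens, 0 marks padding; only real tokens enter the loss.
    attention_masks=torch.tensor([[1, 1, 1, 0],
                                  [1, 1, 0, 0]]))

losses = criterion(logits, img_metas)
print(losses['loss_cls'])  # scalar focal loss over the 5 non-padding tokens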