Note
You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to enjoy the new features and better performance brought by OpenMMLab 2.0. Check out the maintenance plan, changelog, code and documentation of MMOCR 1.0 for more details.
Source code for mmocr.models.common.losses.focal_loss
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    """Multi-class Focal loss implementation.

    For the target class of every element this computes
    ``-(1 - p_t) ** gamma * log(p_t)``, where ``p_t`` is the softmax
    probability of that class. The per-element terms are reduced by
    :func:`torch.nn.functional.nll_loss` with its default ``'mean'``
    reduction (a weighted mean when ``weight`` is given).

    Args:
        gamma (float): The larger the gamma, the smaller
            the loss weight of easier samples. ``gamma=0`` gives plain
            (optionally class-weighted) cross-entropy.
        weight (Tensor | Sequence[float], optional): A manual rescaling
            weight given to each class (one value per class). It is
            converted to a tensor on the input's device and dtype at every
            forward pass, so lists and tensors living on another device are
            accepted. Defaults to None.
        ignore_index (int): Specifies a target value that is ignored
            and does not contribute to the input gradient.
    """

    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super().__init__()
        self.gamma = gamma
        # Kept as given (not registered as a buffer) so the state_dict layout
        # is unchanged; device/dtype are reconciled in forward().
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input (Tensor): Unnormalized class scores with the class
                dimension at dim 1, in any layout ``F.nll_loss`` accepts,
                e.g. (N, C) or (N, C, d1, ...).
            target (Tensor): Class indices matching ``input`` without the
                class dimension; entries equal to ``ignore_index`` are
                skipped.

        Returns:
            Tensor: The reduced (mean) focal loss.
        """
        log_prob = F.log_softmax(input, dim=1)
        prob = torch.exp(log_prob)
        # Scale every log-probability by its modulating factor; nll_loss then
        # selects the target-class entry and negates it.
        focal_log_prob = (1 - prob)**self.gamma * log_prob

        weight = self.weight
        if weight is not None:
            # nll_loss needs weight on the same device/dtype as the input;
            # a plain attribute does not follow module.to()/cuda().
            weight = torch.as_tensor(
                weight, dtype=input.dtype, device=input.device)

        return F.nll_loss(
            focal_log_prob, target, weight, ignore_index=self.ignore_index)