Shortcuts

Source code for mmocr.models.common.losses.focal_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Multi-class Focal loss implementation.

    Computes ``-(1 - p_t) ** gamma * log(p_t)`` for every element, where
    ``p_t`` is the softmax probability of the target class. The result is
    reduced with :func:`torch.nn.functional.nll_loss` (default ``'mean'``
    reduction over non-ignored targets, weighted by ``weight`` when given).
    With ``gamma=0`` this reduces to ordinary (weighted) cross-entropy.

    Args:
        gamma (float): Focusing parameter. The larger the gamma, the smaller
            the loss weight of easier (well-classified) samples.
        weight (Tensor | Sequence[float] | None): A manual rescaling weight
            given to each class; must have one entry per class (length C).
            Lists/tuples are accepted. At every call the weight is converted
            to the device and dtype of the input, so it need not be moved
            manually when the module is moved.
        ignore_index (int): Specifies a target value that is ignored and does
            not contribute to the input gradient.
    """

    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input (Tensor): Unnormalized class scores, shape (N, C) or
                (N, C, d1, ...); softmax is taken over dim 1.
            target (Tensor): Integer class indices, shape (N,) or
                (N, d1, ...), as expected by ``F.nll_loss``.

        Returns:
            Tensor: The reduced loss (``nll_loss`` default mean reduction).
        """
        log_prob = F.log_softmax(input, dim=1)
        prob = torch.exp(log_prob)
        # Modulating factor (1 - p)^gamma shrinks the contribution of
        # confidently-correct (easy) classes towards zero.
        focal_log_prob = (1 - prob)**self.gamma * log_prob

        weight = self.weight
        if weight is not None:
            # nll_loss needs a Tensor weight on the same device/dtype as its
            # input; a plain attribute is not moved by ``module.to()``, and
            # lists/tuples would otherwise be rejected.
            weight = torch.as_tensor(
                weight, dtype=log_prob.dtype, device=log_prob.device)

        return F.nll_loss(
            focal_log_prob, target, weight, ignore_index=self.ignore_index)
Read the Docs v: v0.4.0
Versions
latest
stable
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.