Shortcuts

Source code for mmocr.models.common.losses.dice_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from mmocr.models.builder import LOSSES


@LOSSES.register_module()
class DiceLoss(nn.Module):
    """Dice loss, reduced over the entire batch.

    The loss is ``1 - 2 * sum(p * t) / (sum(p) + sum(t) + eps)``. Every sum
    runs over all elements of all samples, so one scalar is produced for the
    whole batch; per-sample dice values are not averaged.

    The class is registered in the ``LOSSES`` registry, so it can be built
    from a config by name.

    Args:
        eps (float): Constant added to the denominator so the ratio stays
            finite when ``pred`` and ``target`` both sum to zero. It must be a
            ``float`` instance; anything else fails the assertion in
            ``__init__``.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        # Only a real float is accepted (an int such as 0 is rejected).
        assert isinstance(eps, float)
        self.eps = eps

    @staticmethod
    def _flatten(tensor):
        """Keep the first dimension and collapse all remaining ones into one."""
        return tensor.contiguous().view(tensor.size()[0], -1)

    def forward(self, pred, target, mask=None):
        """Compute the batch-level dice loss.

        Args:
            pred (Tensor): Predictions. The first dimension is kept and the
                rest are flattened. Presumably per-pixel probabilities in
                [0, 1]; confirm against the caller.
            target (Tensor): Ground truth. It is flattened the same way as
                ``pred`` and multiplied elementwise with it.
            mask (Tensor, optional): When given, it is flattened the same way
                and multiplied into both ``pred`` and ``target``. Zeroed
                positions therefore contribute nothing to any sum. Because
                all three tensors are flattened to ``(size(0), -1)`` before
                multiplying, the mask only has to match per-sample element
                counts (or broadcast against them), not the exact shape.

        Returns:
            Tensor: A 0-dim tensor equal to ``1 - dice``.
        """
        pred_flat = self._flatten(pred)
        target_flat = self._flatten(target)

        if mask is not None:
            mask_flat = self._flatten(mask)
            pred_flat = pred_flat * mask_flat
            target_flat = target_flat * mask_flat

        intersection = (pred_flat * target_flat).sum()
        denominator = pred_flat.sum() + target_flat.sum() + self.eps
        return 1 - 2 * intersection / denominator
Read the Docs v: v0.4.0
Versions
latest
stable
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.