PSEModuleLoss
- class mmocr.models.textdet.PSEModuleLoss(weight_text=0.7, weight_kernel=0.3, loss_text={'type': 'MaskedSquareDiceLoss'}, loss_kernel={'type': 'MaskedSquareDiceLoss'}, ohem_ratio=3, reduction='mean', kernel_sample_type='adaptive', shrink_ratio=(1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4), max_shrink_dist=20)[source]
The class implementing the PSENet loss, partially adapted from https://github.com/whai362/PSENet.
PSENet: Shape Robust Text Detection with Progressive Scale Expansion Network.
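For orientation, the PSENet paper defines its objective roughly as follows. Mapping the paper's \(\lambda\) and \(1-\lambda\) onto weight_text and weight_kernel, and its masks \(M\) and \(W\) onto OHEM sampling and kernel sampling, is an interpretation of the parameters below, not a transcription of this class's code.

```latex
\begin{aligned}
D(S, G) &= \frac{2\sum_{x,y} S_{x,y}\,G_{x,y}}{\sum_{x,y} S_{x,y}^{2} + \sum_{x,y} G_{x,y}^{2}} \\
L_{\text{text}} &= 1 - D(S_n \cdot M,\; G_n \cdot M) \\
L_{\text{kernel}} &= 1 - \frac{1}{n-1}\sum_{i=1}^{n-1} D(S_i \cdot W,\; G_i \cdot W) \\
L &= \lambda\, L_{\text{text}} + (1-\lambda)\, L_{\text{kernel}}
\end{aligned}
```

Here \(S_n\) and \(G_n\) are the predicted and ground-truth full text maps, \(S_i\) and \(G_i\) the shrunk kernels, \(M\) the OHEM-selected pixels and \(W\) the pixels used for the kernel loss. The defaults weight_text=0.7 and weight_kernel=0.3 correspond to \(\lambda = 0.7\).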
- Parameters
weight_text (float) – The weight of text loss. Defaults to 0.7.
weight_kernel (float) – The weight of the kernel loss. Defaults to 0.3.
loss_text (dict) – Config of the loss used for the text map. Defaults to dict(type='MaskedSquareDiceLoss').
loss_kernel (dict) – Config of the loss used for the kernel maps. Defaults to dict(type='MaskedSquareDiceLoss').
ohem_ratio (int or float) – The negative/positive ratio used in OHEM (online hard example mining). Defaults to 3.
reduction (str) – How to reduce the loss. Options are 'mean' and 'sum'. Defaults to 'mean'.
kernel_sample_type (str) – How pixels are sampled for the kernel loss. Options are 'adaptive' and 'hard'. Defaults to 'adaptive'.
shrink_ratio (tuple) – The ratios for shrinking text instances. Defaults to (1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4).
max_shrink_dist (int or float) – The maximum shrinking distance. Defaults to 20.
- Return type
None
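shrink_ratio and max_shrink_dist govern how the ground-truth kernels are built. In the PSENet paper each text polygon \(p\) is shrunk inward by \(d_i = \mathrm{Area}(p)\,(1 - r_i^2) / \mathrm{Perimeter}(p)\) for each ratio \(r_i\). The sketch below computes only that offset in pure Python; capping it with min() is an assumption about how max_shrink_dist is applied, the helper names are hypothetical, and the actual polygon offsetting (typically done with a clipping library) is omitted.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def _edges(poly: Sequence[Point]):
    """Consecutive vertex pairs, closing the polygon back to its first vertex."""
    pts = list(poly)
    return zip(pts, pts[1:] + pts[:1])


def polygon_area(poly: Sequence[Point]) -> float:
    """Absolute area of a simple polygon (shoelace formula)."""
    return abs(sum(x1 * y2 - x2 * y1 for (x1, y1), (x2, y2) in _edges(poly))) / 2.0


def polygon_perimeter(poly: Sequence[Point]) -> float:
    """Sum of edge lengths."""
    return sum(math.dist(a, b) for a, b in _edges(poly))


def shrink_distances(poly: Sequence[Point],
                     shrink_ratio=(1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4),
                     max_shrink_dist: float = 20) -> List[float]:
    """Hypothetical helper: inward offset per ratio, as in the PSENet paper.

    Assumption: max_shrink_dist is an upper bound applied with min().
    """
    area = polygon_area(poly)
    perimeter = polygon_perimeter(poly)
    return [min(area * (1 - r * r) / perimeter, max_shrink_dist)
            for r in shrink_ratio]
```

For a 100 x 20 box (area 2000, perimeter 240) the offsets grow from 0 at ratio 1.0 to about 7 pixels at ratio 0.4; for a much larger box the cap keeps the smallest kernels from collapsing.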
- forward(preds, data_samples)[source]
Compute PSENet loss.
- Parameters
preds (torch.Tensor) – Raw predictions from the model with shape \((N, C, H, W)\).
data_samples (list[TextDetDataSample]) – The data samples.
- Returns
A dict of PSENet losses with the keys loss_text and loss_kernel.
- Return type
dict
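The following is a simplified, single-image sketch of the kind of computation forward performs, written in pure Python on flattened pixel lists instead of tensors. It assumes channel 0 of preds holds text logits and the remaining channels hold kernel logits, treats gt_mask as the mask of non-ignored pixels, and uses an OHEM scheme in the spirit of the PSENet paper: keep every positive pixel plus the highest-scoring negatives, up to ohem_ratio times the number of positives. All function names are hypothetical, and resizing, batching and the 'mean'/'sum' reduction are omitted.

```python
import math
from typing import Dict, List, Sequence


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def masked_square_dice_loss(pred: Sequence[float], gt: Sequence[float],
                            mask: Sequence[float], eps: float = 1e-6) -> float:
    """1 - 2*sum(p*g) / (sum(p^2) + sum(g^2)), computed on masked pixels."""
    p = [a * m for a, m in zip(pred, mask)]
    g = [b * m for b, m in zip(gt, mask)]
    inter = sum(a * b for a, b in zip(p, g))
    denom = sum(a * a for a in p) + sum(b * b for b in g) + eps
    return 1.0 - 2.0 * inter / denom


def ohem_mask(text_logits: Sequence[float], gt_text: Sequence[float],
              gt_mask: Sequence[float], ohem_ratio: float = 3) -> List[float]:
    """Hypothetical helper: keep all positives plus the hardest negatives."""
    n = len(gt_text)
    pos = [i for i in range(n) if gt_text[i] > 0.5 and gt_mask[i] > 0.5]
    neg = [i for i in range(n) if gt_text[i] <= 0.5]
    num_neg = min(len(neg), int(len(pos) * ohem_ratio))
    if not pos or num_neg == 0:
        # Assumption: fall back to the plain training mask.
        return [float(m) for m in gt_mask]
    # Hardest negatives are the non-text pixels with the highest text scores.
    hardest = sorted(neg, key=lambda i: text_logits[i], reverse=True)[:num_neg]
    selected = set(pos) | set(hardest)
    return [1.0 if i in selected and gt_mask[i] > 0.5 else 0.0 for i in range(n)]


def pse_loss_sketch(text_logits, kernel_logits, gt_text, gt_kernels, gt_mask,
                    weight_text=0.7, weight_kernel=0.3, ohem_ratio=3,
                    kernel_sample_type='adaptive') -> Dict[str, float]:
    """Hypothetical single-image version of the loss computation."""
    # Text branch: square Dice on OHEM-sampled pixels.
    text_mask = ohem_mask(text_logits, gt_text, gt_mask, ohem_ratio)
    loss_text = masked_square_dice_loss(
        [sigmoid(s) for s in text_logits], gt_text, text_mask)

    # Kernel branch: the sampled region depends on kernel_sample_type.
    if kernel_sample_type == 'hard':
        region = [1.0 if g > 0.5 else 0.0 for g in gt_text]  # GT text pixels
    elif kernel_sample_type == 'adaptive':
        region = [1.0 if s > 0 else 0.0 for s in text_logits]  # sigmoid > 0.5
    else:
        raise NotImplementedError(kernel_sample_type)
    kernel_mask = [r * m for r, m in zip(region, gt_mask)]
    kernel_losses = [
        masked_square_dice_loss([sigmoid(s) for s in k_pred], k_gt, kernel_mask)
        for k_pred, k_gt in zip(kernel_logits, gt_kernels)
    ]
    loss_kernel = sum(kernel_losses) / len(kernel_losses)

    return {'loss_text': weight_text * loss_text,
            'loss_kernel': weight_kernel * loss_kernel}
```

A confident, correct prediction drives both entries toward 0, while inverting the text logits pushes loss_text toward weight_text. Keeping the sketch dependency-free trades fidelity for readability; the module itself operates on batched tensors.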