DBHead
- class mmocr.models.textdet.DBHead(in_channels, with_bias=False, module_loss={'type': 'DBModuleLoss'}, postprocessor={'text_repr_type': 'quad', 'type': 'DBPostprocessor'}, init_cfg=[{'type': 'Kaiming', 'layer': 'Conv'}, {'type': 'Constant', 'layer': 'BatchNorm', 'val': 1.0, 'bias': 0.0001}])[source]
The class for DBNet head.
This was partially adapted from https://github.com/MhLiao/DB
- Parameters
in_channels (int) – The number of input channels.
with_bias (bool) – Whether to add bias in the Conv2d layers. Defaults to False.
module_loss (dict) – Config of the loss for DBNet. Defaults to dict(type='DBModuleLoss').
postprocessor (dict) – Config of the postprocessor for DBNet. Defaults to dict(type='DBPostprocessor', text_repr_type='quad').
init_cfg (dict or list[dict], optional) – Initialization configs.
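As a hedged illustration, the constructor arguments above could be written as an MMOCR-style config dict. The values for `with_bias`, `module_loss` and `postprocessor` repeat the defaults from the signature; `in_channels=256` is only an assumed example value, since the right number depends on the upstream module that feeds the head.

```python
# Config-style sketch of the DBHead arguments listed above.
# Only the signature defaults come from the documentation;
# in_channels=256 is an assumed example value.
det_head = dict(
    type='DBHead',
    in_channels=256,  # assumption: must match the upstream feature channels
    with_bias=False,
    module_loss=dict(type='DBModuleLoss'),
    postprocessor=dict(type='DBPostprocessor', text_repr_type='quad'),
)
```

Leaving out `init_cfg` here means the default initialization from the signature would apply.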
- forward(img, data_samples=None, mode='predict')[source]
- Parameters
img (Tensor) – Shape \((N, C, H, W)\).
data_samples (list[TextDetDataSample], optional) – A list of data samples. Defaults to None.
mode (str) – Forward mode. It affects the return values. Options are "loss", "predict" and "both". Defaults to "predict".
loss: Run the full network and return the prob logits, threshold map and binary map.
predict: Run the binarization part and return the prob map only.
both: Run the full network and return the prob logits, threshold map, binary map and prob map.
- Returns
Its type depends on mode; see the description of mode for details. Each returned tensor has the shape \((N, 4H, 4W)\).
- Return type
Tensor or tuple(Tensor)
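The relationship between `mode`, the returned maps and their shape can be sketched in plain Python. This is not the MMOCR implementation: `expected_outputs` is a hypothetical helper, and the output names (`prob_logits`, `thr_map`, `binary_map`, `prob_map`) are illustrative labels for the maps named in the mode description.

```python
from typing import Dict, Tuple

# Outputs returned for each forward mode, following the mode description.
# The names are illustrative labels, not attributes of DBHead.
MODE_OUTPUTS: Dict[str, Tuple[str, ...]] = {
    'loss': ('prob_logits', 'thr_map', 'binary_map'),
    'predict': ('prob_map',),
    'both': ('prob_logits', 'thr_map', 'binary_map', 'prob_map'),
}


def expected_outputs(mode: str, img_shape: Tuple[int, int, int, int]):
    """Hypothetical helper: map each output name to its expected shape."""
    if mode not in MODE_OUTPUTS:
        raise ValueError(f'Unknown mode: {mode!r}')
    n, _, h, w = img_shape  # input shape (N, C, H, W)
    out_shape = (n, 4 * h, 4 * w)  # documented output shape (N, 4H, 4W)
    return {name: out_shape for name in MODE_OUTPUTS[mode]}


# Example with an assumed input of shape (2, 256, 160, 160).
shapes = expected_outputs('both', (2, 256, 160, 160))
```

With this assumed input, every map in `shapes` has the shape `(2, 640, 640)`, and `'predict'` mode yields the single entry `prob_map`.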
- loss(x, batch_data_samples)[source]
Perform forward propagation and loss calculation of the detection head on the features of the upstream network.
- loss_and_predict(x, batch_data_samples)[source]
Perform forward propagation of the head, then calculate loss and predictions from the features and data samples.
- Parameters
x (tuple[Tensor]) – Features from FPN.
batch_data_samples (list[DetDataSample]) – Each item contains the meta information of each image and corresponding annotations.
- Returns
The return value is a tuple containing:
losses (dict[str, Tensor]): A dictionary of loss components.
predictions (list[InstanceData]): Detection results of each image after post-processing.
- Return type
tuple
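A caller would typically unpack the returned tuple and reduce the loss components to one scalar. The sketch below uses plain floats and dicts as stand-ins for tensors and prediction objects; the keys `loss_a`, `loss_b` and `polygons` are made-up placeholders, and summing every entry whose key contains `loss` is an assumed convention rather than documented behaviour of this head.

```python
from typing import Dict, List, Tuple


def total_loss(losses: Dict[str, float]) -> float:
    """Sum every entry whose key contains 'loss' (assumed convention)."""
    return sum(value for key, value in losses.items() if 'loss' in key)


# Stand-in for the (losses, predictions) tuple; keys and values are made up.
result: Tuple[Dict[str, float], List[dict]] = (
    {'loss_a': 0.5, 'loss_b': 0.25},
    [{'polygons': []}, {'polygons': []}],  # one placeholder entry per image
)

losses, predictions = result
print(total_loss(losses))  # 0.75
print(len(predictions))    # 2
```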
- predict(x, batch_data_samples)[source]
Perform forward propagation of the detection head and predict detection results on the features of the upstream network.
- Parameters
x (tuple[Tensor]) – Multi-level features from the upstream network, each is a 4D-tensor.
batch_data_samples (List[DetDataSample]) – The data samples. They usually include information such as gt_instance, gt_panoptic_seg and gt_sem_seg.
- Returns
Detection results of each image after post-processing.
- Return type
SampleList