Shortcuts

Source code for mmocr.models.textdet.heads.pan_head

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch
import torch.nn as nn

from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from mmocr.utils import check_argument
from .base import BaseTextDetHead


[docs]@MODELS.register_module() class PANHead(BaseTextDetHead): """The class for PANet head. Args: in_channels (list[int]): A list of 4 numbers of input channels. hidden_dim (int): The hidden dimension of the first convolutional layer. out_channel (int): Number of output channels. module_loss (dict): Configuration dictionary for loss type. Defaults to dict(type='PANModuleLoss') postprocessor (dict): Config of postprocessor for PANet. Defaults to dict(type='PANPostprocessor', text_repr_type='poly'). init_cfg (list[dict]): Initialization configs. Defaults to [dict(type='Normal', mean=0, std=0.01, layer='Conv2d'), dict(type='Constant', val=1, bias=0, layer='BN')] """ def __init__( self, in_channels: List[int], hidden_dim: int, out_channel: int, module_loss=dict(type='PANModuleLoss'), postprocessor=dict(type='PANPostprocessor', text_repr_type='poly'), init_cfg=[ dict(type='Normal', mean=0, std=0.01, layer='Conv2d'), dict(type='Constant', val=1, bias=0, layer='BN') ] ) -> None: super().__init__( module_loss=module_loss, postprocessor=postprocessor, init_cfg=init_cfg) assert check_argument.is_type_list(in_channels, int) assert isinstance(out_channel, int) assert isinstance(hidden_dim, int) in_channels = sum(in_channels) self.conv1 = nn.Conv2d( in_channels, hidden_dim, kernel_size=3, stride=1, padding=1) self.bn1 = nn.BatchNorm2d(hidden_dim) self.relu1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d( hidden_dim, out_channel, kernel_size=1, stride=1, padding=0)
[docs] def forward(self, inputs: torch.Tensor, data_samples: Optional[List[TextDetDataSample]] = None ) -> torch.Tensor: r"""PAN head forward. Args: inputs (list[Tensor] | Tensor): Each tensor has the shape of :math:`(N, C_i, W, H)`, where :math:`\sum_iC_i=C_{in}` and :math:`C_{in}` is ``input_channels``. data_samples (list[TextDetDataSample], optional): A list of data samples. Defaults to None. Returns: Tensor: A tensor of shape :math:`(N, C_{out}, W, H)` where :math:`C_{out}` is ``output_channels``. """ if isinstance(inputs, tuple): outputs = torch.cat(inputs, dim=1) else: outputs = inputs outputs = self.conv1(outputs) outputs = self.relu1(self.bn1(outputs)) outputs = self.conv2(outputs) return outputs
Read the Docs v: dev-1.x
Versions
latest
stable
v1.0.1
v1.0.0
0.x
v0.6.3
v0.6.2
v0.6.1
v0.6.0
v0.5.0
v0.4.1
v0.4.0
v0.3.0
v0.2.1
v0.2.0
v0.1.0
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.