
Source code for mmocr.models.textdet.heads.textsnake_head

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn

from mmocr.models.textdet.heads import BaseTextDetHead
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample

@MODELS.register_module()
class TextSnakeHead(BaseTextDetHead):
    """Detection head for TextSnake.

    TextSnake: `A Flexible Representation for Detecting Text of Arbitrary
    Shapes <https://arxiv.org/abs/1807.01544>`_.

    The head consists of a single 1x1 convolution (``out_conv``) that maps
    the input feature map to ``out_channels`` prediction maps.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels. Defaults to 5.
        downsample_ratio (float): Downsample ratio. It is stored on the
            instance as ``self.downsample_ratio``; :meth:`forward` does not
            read it. Defaults to 1.0.
        module_loss (dict): Configuration dictionary for loss type.
            Passed through to the base class.
            Defaults to ``dict(type='TextSnakeModuleLoss')``.
        postprocessor (dict): Config of postprocessor for TextSnake.
            Passed through to the base class. Defaults to
            ``dict(type='TextSnakePostprocessor', text_repr_type='poly')``.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Passed through to the base class. Defaults to
            ``dict(type='Normal', override=dict(name='out_conv'), mean=0,
            std=0.01)``.
    """

    # NOTE(review): the dict defaults below are created once and shared by
    # every call that relies on them. This block never mutates them;
    # presumably the base class does not either -- confirm before changing.
    def __init__(
        self,
        in_channels: int,
        out_channels: int = 5,
        downsample_ratio: float = 1.0,
        module_loss: Dict = dict(type='TextSnakeModuleLoss'),
        postprocessor: Dict = dict(
            type='TextSnakePostprocessor', text_repr_type='poly'),
        init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
            type='Normal', override=dict(name='out_conv'), mean=0, std=0.01)
    ) -> None:
        super().__init__(
            module_loss=module_loss,
            postprocessor=postprocessor,
            init_cfg=init_cfg)

        # Kept as asserts so the exception type seen by existing callers
        # (AssertionError) does not change.
        assert isinstance(in_channels, int)
        assert isinstance(out_channels, int)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample_ratio = downsample_ratio

        # 1x1 convolution with stride 1 and no padding: only the channel
        # dimension changes, the spatial size of the input is preserved.
        self.out_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0)

    def forward(
        self,
        inputs: torch.Tensor,
        data_samples: Optional[List[TextDetDataSample]] = None
    ) -> torch.Tensor:
        """Apply the output convolution to the input feature map.

        Args:
            inputs (torch.Tensor): Shape :math:`(N, C_{in}, H, W)`, where
                :math:`C_{in}` is ``in_channels``. :math:`H` and :math:`W`
                should be the same as the input of backbone.
            data_samples (list[TextDetDataSample], optional): A list of data
                samples. Not used by this method. Defaults to None.

        Returns:
            Tensor: A tensor of shape :math:`(N, C_{out}, H, W)`, where
            :math:`C_{out}` is ``out_channels``. With the default
            ``out_channels=5`` the five channels represent [0]: text score,
            [1]: center score, [2]: sin, [3]: cos, [4]: radius, respectively.
        """
        outputs = self.out_conv(inputs)
        return outputs
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.