Shortcuts

mmocr.models.textrecog.preprocessors.tps_preprocessor 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmocr.registry import MODELS
from .base import BasePreprocessor


class TPStransform(nn.Module):
    """Implement TPS (thin-plate-spline) transform.

    This was partially adapted from https://github.com/ayumiymk/aster.pytorch

    Everything that depends only on the fixed *target* control points (the
    inverse TPS kernel and the per-pixel representation of the output grid)
    is precomputed once in ``__init__`` and stored as buffers, so ``forward``
    only needs two matrix products and one ``grid_sample`` call.

    Args:
        output_image_size (tuple[int, int]): The (height, width) of the
            output image. Defaults to (32, 100).
        num_control_points (int): The number of control points. Half of them
            are placed along the top side and half along the bottom side, so
            the value must be a positive even number. Defaults to 20.
        margins (tuple[float, float]): The (x, y) margins of the target
            control points, in normalized [0, 1] image coordinates.
            Defaults to (0.05, 0.05).

    Raises:
        ValueError: If ``num_control_points`` is not a positive even number.
    """

    def __init__(self,
                 output_image_size: Tuple[int, int] = (32, 100),
                 num_control_points: int = 20,
                 margins: Tuple[float, float] = (0.05, 0.05)) -> None:
        super().__init__()
        # The control points are split evenly between the top and the bottom
        # side; an odd count would silently build N - 1 points and only fail
        # later with an opaque shape mismatch when filling the kernel.
        if num_control_points <= 0 or num_control_points % 2 != 0:
            raise ValueError(
                'num_control_points must be a positive even number, '
                f'but got {num_control_points}')
        self.output_image_size = output_image_size
        self.num_control_points = num_control_points
        self.margins = margins
        self.target_height, self.target_width = output_image_size

        # build output control points
        target_control_points = self._build_output_control_points(
            num_control_points, margins)
        N = num_control_points

        # create padded kernel matrix:
        #   [ K    1   P ]
        #   [ 1^T  0   0 ]
        #   [ P^T  0   0 ]
        # where K is the pairwise radial-basis matrix of the target control
        # points P.
        forward_kernel = torch.zeros(N + 3, N + 3)
        target_control_partial_repr = self._compute_partial_repr(
            target_control_points, target_control_points)
        forward_kernel[:N, :N].copy_(target_control_partial_repr)
        forward_kernel[:N, -3].fill_(1)
        forward_kernel[-3, :N].fill_(1)
        forward_kernel[:N, -2:].copy_(target_control_points)
        forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))

        # compute inverse matrix
        inverse_kernel = torch.inverse(forward_kernel).contiguous()

        # create target coordinate matrix: one (x, y) row per output pixel in
        # row-major order, normalized to [0, 1]. ``max(..., 1)`` keeps a
        # 1-pixel-high/wide output at coordinate 0 instead of producing
        # 0 / 0 = NaN.
        H, W = self.target_height, self.target_width
        HW = H * W
        dtype = torch.get_default_dtype()
        rows = torch.arange(H, dtype=dtype) / max(H - 1, 1)
        cols = torch.arange(W, dtype=dtype) / max(W - 1, 1)
        Y = rows.view(-1, 1).expand(H, W).reshape(-1, 1)
        X = cols.view(1, -1).expand(H, W).reshape(-1, 1)
        tgt_coord = torch.cat([X, Y], dim=1)
        tgt_coord_partial_repr = self._compute_partial_repr(
            tgt_coord, target_control_points)
        tgt_coord_repr = torch.cat(
            [tgt_coord_partial_repr,
             torch.ones(HW, 1), tgt_coord], dim=1)

        # register precomputed matrices
        self.register_buffer('inverse_kernel', inverse_kernel)
        self.register_buffer('padding_matrix', torch.zeros(3, 2))
        self.register_buffer('target_coordinate_repr', tgt_coord_repr)
        self.register_buffer('target_control_points', target_control_points)

    def forward(self, input: torch.Tensor,
                source_control_points: torch.Tensor) -> torch.Tensor:
        """Forward function of the TPS block.

        Args:
            input (Tensor): The input image batch that is sampled with
                ``F.grid_sample`` using a 4-D grid.
            source_control_points (Tensor): The control points of the source
                image of shape (N, self.num_control_points, 2), given as
                (x, y) in normalized [0, 1] coordinates.

        Returns:
            Tensor: The output image after TPS transform, with spatial size
            ``output_image_size``.
        """
        assert source_control_points.ndimension() == 3, \
            'source_control_points must be a 3-D tensor'
        assert source_control_points.size(1) == self.num_control_points, \
            'unexpected number of source control points'
        assert source_control_points.size(2) == 2, \
            'each source control point must have 2 coordinates'
        batch_size = source_control_points.size(0)

        # Right-hand side of the TPS system: control points padded with the
        # three zero rows that match the affine constraints of the kernel.
        Y = torch.cat([
            source_control_points,
            self.padding_matrix.expand(batch_size, 3, 2)
        ], 1)
        mapping_matrix = torch.matmul(self.inverse_kernel, Y)
        source_coordinate = torch.matmul(self.target_coordinate_repr,
                                         mapping_matrix)

        grid = source_coordinate.view(-1, self.target_height,
                                      self.target_width, 2)
        # Keep sampling locations inside the image, then rescale from [0, 1]
        # to the [-1, 1] range expected by grid_sample.
        grid = torch.clamp(grid, 0, 1)
        grid = 2.0 * grid - 1.0
        output_maps = self._grid_sample(input, grid, canvas=None)
        return output_maps

    def _grid_sample(self,
                     input: torch.Tensor,
                     grid: torch.Tensor,
                     canvas: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Sample the input image at the given grid.

        Args:
            input (Tensor): The input image.
            grid (Tensor): The grid to sample the input image.
            canvas (Optional[Tensor]): If given, it is blended into the
                output wherever the sampling mask is below 1. Defaults to
                None, in which case the sampled image is returned as is.

        Returns:
            Tensor: The sampled image.
        """
        output = F.grid_sample(input, grid, align_corners=True)
        if canvas is None:
            return output
        # Sampling an all-ones image gives the per-pixel coverage of the
        # input, which is used as the blending weight against the canvas.
        input_mask = torch.ones_like(input)
        output_mask = F.grid_sample(input_mask, grid, align_corners=True)
        return output * output_mask + canvas * (1 - output_mask)

    def _compute_partial_repr(self, input_points: torch.Tensor,
                              control_points: torch.Tensor) -> torch.Tensor:
        """Compute the partial representation matrix.

        Entry (i, j) is ``0.5 * d * log(d)`` where ``d`` is the squared
        Euclidean distance between input point i and control point j.

        Args:
            input_points (Tensor): The input points of shape (N, 2).
            control_points (Tensor): The control points of shape (M, 2).

        Returns:
            Tensor: The partial representation matrix of shape (N, M).
        """
        N = input_points.size(0)
        M = control_points.size(0)
        pairwise_diff = input_points.view(N, 1, 2) - control_points.view(
            1, M, 2)
        pairwise_dist = (pairwise_diff * pairwise_diff).sum(dim=2)
        repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
        # Coincident points give 0 * log(0) = 0 * -inf = NaN; the limit of
        # the basis function there is 0.
        repr_matrix.masked_fill_(torch.isnan(repr_matrix), 0)
        return repr_matrix

    # output_ctrl_pts are specified, according to our task.
    def _build_output_control_points(self, num_control_points: int,
                                     margins: Tuple[float,
                                                    float]) -> torch.Tensor:
        """Build the output control points.

        The output points will be fixed at the top and bottom side of the
        image: ``num_control_points // 2`` points evenly spaced in x on each
        side, top side first.

        Args:
            num_control_points (int): The number of control points.
            margins (Tuple[float, float]): The (x, y) margins for control
                points, in normalized [0, 1] coordinates.

        Returns:
            Tensor: The output control points of shape
            (2 * (num_control_points // 2), 2), as (x, y) pairs.
        """
        margin_x, margin_y = margins
        num_ctrl_pts_per_side = num_control_points // 2
        ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x,
                                 num_ctrl_pts_per_side)
        ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
        ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom],
                                             axis=0)
        output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
        return output_ctrl_pts


@MODELS.register_module()
class STN(BasePreprocessor):
    """Implement STN module in ASTER: An Attentional Scene Text Recognizer
    with Flexible Rectification
    (https://ieeexplore.ieee.org/abstract/document/8395027/)

    A small conv network predicts source control points from a downsampled
    copy of the image; a TPS transform then samples the original image at
    those points to produce the rectified output.

    Args:
        in_channels (int): The number of input channels.
        resized_image_size (Tuple[int, int]): The resized image size. The
            input image will be downsampled to have a better rectified
            result.
        output_image_size: The size of the output image for TPS. Defaults to
            (32, 100).
        num_control_points: The number of control points. Defaults to 20.
        margins: The margins for control points to the top and down side of
            the image for TPS. Defaults to [0.05, 0.05].
        init_cfg: The initialization config forwarded to the base class.
            Defaults to Xavier for ``Conv2d`` and constant 1 for
            ``BatchNorm2d``.
    """

    def __init__(self,
                 in_channels: int,
                 resized_image_size: Tuple[int, int] = (32, 64),
                 output_image_size: Tuple[int, int] = (32, 100),
                 num_control_points: int = 20,
                 margins: Tuple[float, float] = [0.05, 0.05],
                 init_cfg: Optional[Union[Dict, List[Dict]]] = [
                     dict(type='Xavier', layer='Conv2d'),
                     dict(type='Constant', val=1, layer='BatchNorm2d'),
                 ]):
        super().__init__(init_cfg=init_cfg)
        self.resized_image_size = resized_image_size
        self.num_control_points = num_control_points
        self.tps = TPStransform(output_image_size, num_control_points, margins)

        # Localization network: six conv blocks with a 2x2 max-pool between
        # consecutive blocks (five pools in total, none after the last one).
        widths = [in_channels, 32, 64, 128, 256, 256, 256]
        conv_layers = []
        for index, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            if index > 0:
                conv_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            conv_layers.append(
                ConvModule(c_in, c_out, 3, 1, 1, norm_cfg=dict(type='BN')))
        self.stn_convnet = nn.Sequential(*conv_layers)

        # NOTE(review): 2 * 256 presumably matches 256 channels on a 1x2
        # feature map for the default 32x64 resized input - confirm before
        # changing ``resized_image_size``.
        flattened_dim = 2 * 256
        self.stn_fc1 = nn.Sequential(
            nn.Linear(flattened_dim, 512), nn.BatchNorm1d(512),
            nn.ReLU(inplace=True))
        self.stn_fc2 = nn.Linear(512, num_control_points * 2)
        self.init_stn(self.stn_fc2)

    def init_stn(self, stn_fc2: nn.Linear) -> None:
        """Initialize the output linear layer of stn, so that the initial
        source point will be at the top and down side of the image, which
        will help to optimize.

        The weight is zeroed and the bias is set to the flattened control
        points, so the layer initially outputs those points for any input.

        Args:
            stn_fc2 (nn.Linear): The output linear layer of stn.
        """
        margin = 0.01
        per_side = int(self.num_control_points / 2)
        xs = np.linspace(margin, 1. - margin, per_side)
        top_side = np.stack([xs, np.ones(per_side) * margin], axis=1)
        bottom_side = np.stack([xs, np.ones(per_side) * (1 - margin)], axis=1)
        init_points = np.concatenate([top_side, bottom_side],
                                     axis=0).astype(np.float32)
        stn_fc2.weight.data.zero_()
        stn_fc2.bias.data = torch.Tensor(init_points).view(-1)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Forward function of STN.

        Args:
            img (Tensor): The input image tensor.

        Returns:
            Tensor: The rectified image tensor.
        """
        # Control points are predicted from a downsampled copy, while the
        # TPS sampling below is applied to the original image.
        downsampled = F.interpolate(
            img, self.resized_image_size, mode='bilinear', align_corners=True)
        conv_feat = self.stn_convnet(downsampled)
        batch_size, _, _, _ = conv_feat.size()
        flat_feat = conv_feat.view(batch_size, -1)
        hidden = self.stn_fc1(flat_feat)
        ctrl_points = self.stn_fc2(0.1 * hidden)
        ctrl_points = ctrl_points.view(-1, self.num_control_points, 2)
        rectified = self.tps(img, ctrl_points)
        return rectified
Read the Docs v: latest
Versions
latest
stable
0.x
dev-1.x
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.