Note
You are reading the documentation for MMOCR 0.x, which will be deprecated by the end of 2022. We recommend upgrading to MMOCR 1.0 to benefit from the new features and better performance brought by OpenMMLab 2.0. See the maintenance plan, changelog, code, and documentation of MMOCR 1.0 for more details.
Source code for mmocr.utils.model
# Copyright (c) OpenMMLab. All rights reserved.
import torch
class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
    """Batch normalization layer that skips the input-dimension check.

    Reproduced from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)

    ``BatchNorm1d``, ``BatchNorm2d`` and ``BatchNorm3d`` differ only in
    ``_check_input_dim``, the hook that sanity-checks the input tensor.
    Overriding that hook with a no-op lets a single class stand in for any of
    them. This is convenient when converting ``SyncBatchNorm`` layers back to
    ordinary batch norm.
    """

    def _check_input_dim(self, input):
        # Intentionally accept any input: no dimensionality validation here.
        pass
def revert_sync_batchnorm(module: torch.nn.Module) -> torch.nn.Module:
    """Recursively convert all `SyncBatchNorm` layers to `_BatchNormXd`.

    Adapted from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)

    The replacement layer does not copy the original's state; it shares it.
    The same affine parameter objects (``weight``/``bias``, when ``affine``)
    and running-statistics buffers are assigned to it. Its ``training`` flag
    and, if present, ``qconfig`` are copied over as well.

    Args:
        module (nn.Module): The module containing `SyncBatchNorm` layers.

    Returns:
        nn.Module: The converted module. If ``module`` is itself a
        `SyncBatchNorm`, a new `_BatchNormXd` is returned. Otherwise the same
        ``module`` object is returned, with its `SyncBatchNorm` descendants
        replaced in place.
    """
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
        module_output = _BatchNormXd(module.num_features, module.eps,
                                     module.momentum, module.affine,
                                     module.track_running_stats)
        if module.affine:
            # Share (rather than copy) the learnable parameters.
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        # Running statistics are carried over whether or not the layer is
        # affine.
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        module_output.training = module.training
        # Preserve the quantization config if the layer carries one.
        if hasattr(module, 'qconfig'):
            module_output.qconfig = module.qconfig
    # Recurse into every child; ``add_module`` overwrites a child registered
    # under the same name, which is how non-SyncBN parents are updated in place.
    for name, child in module.named_children():
        module_output.add_module(name, revert_sync_batchnorm(child))
    return module_output