# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch.nn as nn
def c2_xavier_fill(module: nn.Module) -> None:
    """
    Initialize `module.weight` using the "XavierFill" implemented in Caffe2.
    Also initializes `module.bias` to 0 when the module has a non-None bias.

    Args:
        module (torch.nn.Module): module to initialize. It must expose a
            `weight` tensor; a `bias` attribute is optional (missing or None
            biases are left alone).
    """
    # Caffe2 implementation of XavierFill in fact corresponds to
    # kaiming_uniform_ in PyTorch: with a=1 (default leaky_relu nonlinearity)
    # the gain is sqrt(2 / (1 + a^2)) = 1, so weights are drawn from
    # U(-sqrt(3 / fan_in), sqrt(3 / fan_in)).
    # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Module, Tensor]`.
    nn.init.kaiming_uniform_(module.weight, a=1)
    # getattr() so modules that define no `bias` attribute at all are skipped
    # instead of raising AttributeError; an explicit `bias=None`
    # (e.g. nn.Linear(..., bias=False)) is skipped as well.
    bias = getattr(module, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, 0)
def c2_msra_fill(module: nn.Module) -> None:
    """
    Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
    Also initializes `module.bias` to 0 when the module has a non-None bias.

    Args:
        module (torch.nn.Module): module to initialize. It must expose a
            `weight` tensor; a `bias` attribute is optional (missing or None
            biases are left alone).
    """
    # kaiming_normal_ with mode="fan_out" and nonlinearity="relu" samples
    # weights from N(0, std^2) with std = sqrt(2 / fan_out).
    # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Module, Tensor]`.
    nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    # getattr() so modules that define no `bias` attribute at all are skipped
    # instead of raising AttributeError; an explicit `bias=None`
    # (e.g. nn.Conv2d(..., bias=False)) is skipped as well.
    bias = getattr(module, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, 0)