import math
from torch.autograd import Variable
import torch
import torch.nn as nn
import warnings
# Install a process-wide "ignore" filter for every warning category so the
# demo output below stays free of warning noise. This is equivalent to
# warnings.filterwarnings("ignore") and also hides any deprecation warnings
# raised by later calls.
warnings.simplefilter("ignore")
def _calculate_fan_in_and_fan_out(tensor):
    """Return ``(fan_in, fan_out)`` for a weight tensor.

    For a 2-D tensor (the ``# Linear`` case), ``fan_in`` is ``size(1)`` and
    ``fan_out`` is ``size(0)``. For tensors with more than two dimensions,
    dim 1 is treated as the number of input feature maps and dim 0 as the
    number of output feature maps. Both are scaled by the receptive-field
    size, which is the product of all remaining dimensions (``shape[2:]``).

    Intermediate values are printed to stdout for demonstration purposes.

    Args:
        tensor: a tensor with at least two dimensions.

    Returns:
        A tuple ``(fan_in, fan_out)`` of ints.

    Raises:
        ValueError: if ``tensor`` has fewer than two dimensions.
    """
    print("***********_calculate_fan_in_and_fan_out****************")
    dimensions = tensor.dim()
    print("dimensions", dimensions)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    if dimensions == 2:  # Linear
        fan_in = tensor.size(1)
        fan_out = tensor.size(0)
        print("fan_in", fan_in)
        print("fan_out", fan_out)
        return fan_in, fan_out

    num_input_fmaps = tensor.size(1)
    num_output_fmaps = tensor.size(0)
    print("num_input_fmaps", num_input_fmaps)
    print("num_output_fmaps", num_output_fmaps)

    # Product of the trailing dimensions. It is computed from the shape
    # instead of ``tensor[0][0].numel()`` so that a zero-sized dim 0 or dim 1
    # does not raise IndexError.
    receptive_field_size = 1
    for extent in tensor.shape[2:]:
        receptive_field_size *= extent
    print("receptive_field_size", receptive_field_size)

    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out
def xavier_uniform(tensor, gain=1):
    """Fill ``tensor`` in place with Xavier (Glorot) uniform values.

    With ``std = gain * sqrt(2 / (fan_in + fan_out))``, values are drawn from
    ``U(-a, a)`` where ``a = sqrt(3) * std``. A uniform distribution on
    ``[-a, a]`` has standard deviation ``a / sqrt(3)``, so the result has the
    target ``std``. Intermediate values are printed to stdout.

    Args:
        tensor: a tensor with at least two dimensions. It is modified in place.
        gain: scaling factor applied to the standard deviation.

    Returns:
        ``tensor`` itself, after the in-place fill.

    Raises:
        ValueError: if ``tensor`` has fewer than two dimensions.
    """
    print("****************xavier_uniform*****************")
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    print("fan_in", fan_in)
    print("fan_out", fan_out)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    print("std", std)
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    print("a", a)
    # Initialisation must not be recorded by autograd. Without no_grad, an
    # in-place fill of a leaf that requires grad (e.g. an nn.Parameter)
    # raises RuntimeError.
    with torch.no_grad():
        return tensor.uniform_(-a, a)
def xavier_normal(tensor, gain=1):
    """Fill ``tensor`` in place with Xavier (Glorot) normal values.

    Values are drawn from a normal distribution with mean 0 and standard
    deviation ``std = gain * sqrt(2 / (fan_in + fan_out))``. Intermediate
    values are printed to stdout.

    Args:
        tensor: a tensor with at least two dimensions. It is modified in place.
        gain: scaling factor applied to the standard deviation.

    Returns:
        ``tensor`` itself, after the in-place fill.

    Raises:
        ValueError: if ``tensor`` has fewer than two dimensions.
    """
    print("****************xavier_normal*****************")
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    print("fan_in", fan_in)
    print("fan_out", fan_out)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    print("std", std)
    # Initialisation must not be recorded by autograd. Without no_grad, an
    # in-place fill of a leaf that requires grad (e.g. an nn.Parameter)
    # raises RuntimeError.
    with torch.no_grad():
        return tensor.normal_(0, std)
# Demo: initialise a 3x5 weight matrix (fan_in=5, fan_out=3) with each scheme.
# The results are stored under separate names so the module-level functions
# ``xavier_uniform`` and ``xavier_normal`` are not shadowed by their own
# return values. Each scheme gets its own tensor because both initialisers
# fill their argument in place; sharing one tensor would make the "uniform"
# result alias the later normal values.
w = torch.empty(3, 5)
w_uniform = xavier_uniform(tensor=w, gain=1)
print("xavier_uniform", w_uniform)
w_normal = xavier_normal(tensor=torch.empty(3, 5), gain=1)
print("xavier_normal", w_normal)

# Sample output from one run (tensor values are random and will differ):
'''
****************xavier_uniform*****************
***********_calculate_fan_in_and_fan_out****************
dimensions 2
fan_in 5
fan_out 3
fan_in 5
fan_out 3
std 0.5
a 0.8660254037844386
xavier_uniform tensor([[-0.0043, -0.6705, -0.4981, -0.6935, 0.3967],
[ 0.3643, 0.2465, 0.6906, -0.2256, -0.7046],
[ 0.6660, 0.7381, 0.5887, 0.0423, 0.2840]])
****************xavier_normal*****************
***********_calculate_fan_in_and_fan_out****************
dimensions 2
fan_in 5
fan_out 3
fan_in 5
fan_out 3
std 0.5
xavier_normal tensor([[ 0.6554, -0.3533, -0.2101, -0.0362, 0.3919],
[ 0.4505, -0.8219, 0.5489, 0.7568, -0.5317],
[-0.2396, 0.1093, -0.3372, -0.1136, 0.4452]])
'''