# xavier_uniform/xavier_normal: Xavier (Glorot) initialization walkthrough

import math
from torch.autograd import Variable
import torch
import torch.nn as nn


import warnings
# NOTE(review): with no category/message/module arguments this installs an
# "ignore" filter for every warning in the process, including any deprecation
# notices. Confirm which warning this was meant to silence and narrow it
# (e.g. category=..., module=...) so unrelated warnings stay visible.
warnings.filterwarnings("ignore")

def _calculate_fan_in_and_fan_out(tensor):
    print("***********_calculate_fan_in_and_fan_out****************")
    dimensions = tensor.dim()
    print("dimensions",dimensions)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    if dimensions == 2:  # Linear
        fan_in = tensor.size(1)
        fan_out = tensor.size(0)
        print("fan_in",fan_in)
        print("fan_out",fan_out)
    else:
        num_input_fmaps = tensor.size(1)
        num_output_fmaps = tensor.size(0)
        print("num_input_fmaps",num_input_fmaps)
        print("num_output_fmaps", num_output_fmaps)
        receptive_field_size = 1
        if tensor.dim() > 2:
            receptive_field_size = tensor[0][0].numel()
            print("receptive_field_size", receptive_field_size)

        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out


def xavier_uniform(tensor, gain=1):
    """Fill ``tensor`` in place with Xavier (Glorot) uniform samples.

    Uses ``std = gain * sqrt(2 / (fan_in + fan_out))`` and samples from
    U(-a, a) with ``a = sqrt(3) * std``, which gives that standard deviation.
    Intermediate values are printed as part of this walkthrough.

    Args:
        tensor: a ``torch.Tensor`` with at least two dimensions; it is
            overwritten in place.
        gain: scaling factor applied to the standard deviation.

    Returns:
        The same ``tensor`` object, now filled with samples.

    Raises:
        ValueError: propagated from ``_calculate_fan_in_and_fan_out`` when
            ``tensor`` has fewer than two dimensions.
    """
    print("****************xavier_uniform*****************")

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    print("fan_in", fan_in)
    print("fan_out", fan_out)

    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    print("std", std)
    # Var(U(-a, a)) = a**2 / 3, so a = sqrt(3) * std yields the target std.
    a = math.sqrt(3.0) * std
    print("a", a)
    # Fill without autograd tracking: an in-place op on a leaf tensor that
    # requires grad (e.g. an nn.Parameter) would otherwise raise RuntimeError.
    with torch.no_grad():
        return tensor.uniform_(-a, a)

def xavier_normal(tensor, gain=1):
    """Fill ``tensor`` in place with Xavier (Glorot) normal samples.

    Samples from N(0, std**2) with ``std = gain * sqrt(2 / (fan_in + fan_out))``.
    Intermediate values are printed as part of this walkthrough.

    Args:
        tensor: a ``torch.Tensor`` with at least two dimensions; it is
            overwritten in place.
        gain: scaling factor applied to the standard deviation.

    Returns:
        The same ``tensor`` object, now filled with samples.

    Raises:
        ValueError: propagated from ``_calculate_fan_in_and_fan_out`` when
            ``tensor`` has fewer than two dimensions.
    """
    print("****************xavier_normal*****************")

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    print("fan_in", fan_in)
    print("fan_out", fan_out)

    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    print("std", std)

    # Fill without autograd tracking: an in-place op on a leaf tensor that
    # requires grad (e.g. an nn.Parameter) would otherwise raise RuntimeError.
    with torch.no_grad():
        return tensor.normal_(0, std)


if __name__ == "__main__":
    # Demo: initialise a 3x5 weight (fan_in=5, fan_out=3) both ways.
    # Both init functions fill their argument in place and return it, so
    # separate tensors keep the two results from aliasing one another.
    # Distinct variable names avoid rebinding the xavier_uniform /
    # xavier_normal functions to tensors.
    w_uniform = torch.empty(3, 5)
    w_uniform = xavier_uniform(tensor=w_uniform, gain=1)
    print("xavier_uniform", w_uniform)

    w_normal = torch.empty(3, 5)
    w_normal = xavier_normal(tensor=w_normal, gain=1)
    print("xavier_normal", w_normal)



'''

****************xavier_uniform*****************
***********_calculate_fan_in_and_fan_out****************
dimensions 2
fan_in 5
fan_out 3
fan_in 5
fan_out 3
std 0.5
a 0.8660254037844386
xavier_uniform tensor([[-0.0043, -0.6705, -0.4981, -0.6935,  0.3967],
        [ 0.3643,  0.2465,  0.6906, -0.2256, -0.7046],
        [ 0.6660,  0.7381,  0.5887,  0.0423,  0.2840]])
****************xavier_normal*****************
***********_calculate_fan_in_and_fan_out****************
dimensions 2
fan_in 5
fan_out 3
fan_in 5
fan_out 3
std 0.5
xavier_normal tensor([[ 0.6554, -0.3533, -0.2101, -0.0362,  0.3919],
        [ 0.4505, -0.8219,  0.5489,  0.7568, -0.5317],
        [-0.2396,  0.1093, -0.3372, -0.1136,  0.4452]])




'''
# 知识兔 (Zhishitu)

  

# 计算机 (computer)