加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Java > 正文

xavier_uniform/xavier_normal

发布时间:2020-12-15 07:33:32 所属栏目:Java 来源:网络整理
导读:本文用 PyTorch 手动实现 Xavier 均匀初始化(xavier_uniform)与正态初始化(xavier_normal),并打印 fan_in、fan_out、std 等中间结果。
import math
from torch.autograd import Variable
import torch
import torch.nn as nn


import warnings
warnings.filterwarnings("ignore")

def _calculate_fan_in_and_fan_out(tensor):
    print("***********_calculate_fan_in_and_fan_out****************")
    dimensions = tensor.dim()
    print("dimensions",dimensions)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    if dimensions == 2:  # Linear
        fan_in = tensor.size(1)
        fan_out = tensor.size(0)
        print("fan_in",fan_in)
        print("fan_out",fan_out)
    else:
        num_input_fmaps = tensor.size(1)
        num_output_fmaps = tensor.size(0)
        print("num_input_fmaps",num_input_fmaps)
        print("num_output_fmaps",num_output_fmaps)
        receptive_field_size = 1
        if tensor.dim() > 2:
            receptive_field_size = tensor[0][0].numel()
            print("receptive_field_size",receptive_field_size)

        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size

    return fan_in,fan_out


def xavier_uniform(tensor,gain=1):
    """Fill ``tensor`` in place from U(-a, a) using Xavier/Glorot scaling.

    ``std = gain * sqrt(2 / (fan_in + fan_out))`` and ``a = sqrt(3) * std``.
    The bound is chosen so that the uniform distribution has standard
    deviation ``std``. Intermediate values are printed.

    Args:
        tensor: tensor with at least two dimensions; overwritten in place.
        gain: scaling factor applied to the standard deviation.

    Returns:
        The same ``tensor`` object after it has been filled.

    Raises:
        ValueError: from ``_calculate_fan_in_and_fan_out`` when ``tensor`` has
            fewer than two dimensions.
    """
    print("****************xavier_uniform*****************")

    fan_in,fan_out = _calculate_fan_in_and_fan_out(tensor)
    print("fan_in",fan_in)
    print("fan_out",fan_out)

    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    print("std",std)
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    print("a",a)
    # Initialization must not be recorded by autograd. Without no_grad, an
    # in-place fill of a leaf tensor with requires_grad=True raises RuntimeError.
    with torch.no_grad():
        return tensor.uniform_(-a,a)

def xavier_normal(tensor,gain=1):
    """Fill ``tensor`` in place from N(0, std^2) using Xavier/Glorot scaling.

    ``std = gain * sqrt(2 / (fan_in + fan_out))``. Intermediate values are
    printed.

    Args:
        tensor: tensor with at least two dimensions; overwritten in place.
        gain: scaling factor applied to the standard deviation.

    Returns:
        The same ``tensor`` object after it has been filled.

    Raises:
        ValueError: from ``_calculate_fan_in_and_fan_out`` when ``tensor`` has
            fewer than two dimensions.
    """
    print("****************xavier_normal*****************")

    fan_in,fan_out = _calculate_fan_in_and_fan_out(tensor)
    print("fan_in",fan_in)
    print("fan_out",fan_out)

    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    print("std",std)

    # Initialization must not be recorded by autograd. Without no_grad, an
    # in-place fill of a leaf tensor with requires_grad=True raises RuntimeError.
    with torch.no_grad():
        return tensor.normal_(0,std)


# Demo: initialize a 3x5 weight matrix with each scheme and print the result.
# The results get their own names so the xavier_uniform / xavier_normal
# functions are not rebound (shadowed) by the tensors they return. The normal
# init also gets its own tensor, so it does not overwrite the uniform result
# in place.
w = torch.empty(3,5)
w_uniform = xavier_uniform(tensor=w,gain=1)
print("xavier_uniform",w_uniform)

w_normal = xavier_normal(tensor=torch.empty(3,5),gain=1)
print("xavier_normal",w_normal)



'''
Sample output (random tensor values differ between runs):

****************xavier_uniform*****************
***********_calculate_fan_in_and_fan_out****************
dimensions 2
fan_in 5
fan_out 3
fan_in 5
fan_out 3
std 0.5
a 0.8660254037844386
xavier_uniform tensor([[-0.0043,-0.6705,-0.4981,-0.6935,0.3967],[ 0.3643,0.2465,0.6906,-0.2256,-0.7046],[ 0.6660,0.7381,0.5887,0.0423,0.2840]])
****************xavier_normal*****************
***********_calculate_fan_in_and_fan_out****************
dimensions 2
fan_in 5
fan_out 3
fan_in 5
fan_out 3
std 0.5
xavier_normal tensor([[ 0.6554,-0.3533,-0.2101,-0.0362,0.3919],[ 0.4505,-0.8219,0.5489,0.7568,-0.5317],[-0.2396,0.1093,-0.3372,-0.1136,0.4452]])




'''

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时联系站长删除相关内容!

    推荐文章
      热点阅读