xavier_uniform/xavier_normal
发布时间:2020-12-15 07:33:32 所属栏目:Java 来源:网络整理
导读:import math from torch.autograd import Variable import torch import torch.nn as nn import warnings warnings.filterwarnings("ignore") def _calculate_fan_in_and_fan_out(tensor): print("***********_calculate_fan_in_and_fan_out****************") dimensions = tensor.dim() …
import math
from torch.autograd import Variable  # unused below; kept from the original listing
import torch
import torch.nn as nn
import warnings

# NOTE(review): this silences *every* warning for the whole process, not just
# ones raised by this demo. Kept from the original listing.
warnings.filterwarnings("ignore")


def _calculate_fan_in_and_fan_out(tensor):
    """Return ``(fan_in, fan_out)`` for a weight tensor, printing diagnostics.

    Dimension 1 is used as the input-feature count and dimension 0 as the
    output-feature count. For tensors with more than 2 dimensions, the product
    of all trailing dimensions (the "receptive field") multiplies both fans.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    print("***********_calculate_fan_in_and_fan_out****************")
    dimensions = tensor.dim()
    print("dimensions", dimensions)
    if dimensions < 2:
        raise ValueError(
            "Fan in and fan out can not be computed "
            "for tensor with fewer than 2 dimensions"
        )

    if dimensions == 2:  # Linear
        fan_in = tensor.size(1)
        fan_out = tensor.size(0)
        print("fan_in", fan_in)
        print("fan_out", fan_out)
    else:
        num_input_fmaps = tensor.size(1)
        num_output_fmaps = tensor.size(0)
        print("num_input_fmaps", num_input_fmaps)
        print("num_output_fmaps", num_output_fmaps)
        # Product of the trailing dimensions. It is computed from the shape
        # instead of ``tensor[0][0].numel()`` so that a zero-sized leading
        # dimension does not raise IndexError.
        receptive_field_size = 1
        for size in tensor.shape[2:]:
            receptive_field_size *= size
        print("receptive_field_size", receptive_field_size)
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out


def _xavier_std(tensor, gain):
    """Return ``gain * sqrt(2 / (fan_in + fan_out))`` for ``tensor``, printing diagnostics."""
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    print("fan_in", fan_in)
    print("fan_out", fan_out)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    print("std", std)
    return std


def xavier_uniform(tensor, gain=1):
    """Fill ``tensor`` in place from U(-a, a) and return it.

    ``a = gain * sqrt(6 / (fan_in + fan_out))``, i.e. ``sqrt(3) * std`` where
    ``std = gain * sqrt(2 / (fan_in + fan_out))``. Intermediate values are
    printed to stdout.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    print("****************xavier_uniform*****************")
    std = _xavier_std(tensor, gain)
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    print("a", a)
    # no_grad allows initialising leaf tensors that require grad (e.g. an
    # nn.Parameter); an in-place op on such a tensor otherwise raises.
    with torch.no_grad():
        return tensor.uniform_(-a, a)


def xavier_normal(tensor, gain=1):
    """Fill ``tensor`` in place from N(0, std^2) and return it.

    ``std = gain * sqrt(2 / (fan_in + fan_out))``. Intermediate values are
    printed to stdout.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    print("****************xavier_normal*****************")
    std = _xavier_std(tensor, gain)
    with torch.no_grad():
        return tensor.normal_(0, std)


if __name__ == "__main__":
    # Use separate tensors so the uniform result is not overwritten by the
    # normal initialisation. Also avoid rebinding the function names.
    w_uniform = torch.empty(3, 5)
    w_normal = torch.empty(3, 5)
    uniform_result = xavier_uniform(tensor=w_uniform, gain=1)
    print("xavier_uniform", uniform_result)
    normal_result = xavier_normal(tensor=w_normal, gain=1)
    print("xavier_normal", normal_result)

# Sample output (tensor values are random and differ between runs):
#
# ****************xavier_uniform*****************
# ***********_calculate_fan_in_and_fan_out****************
# dimensions 2
# fan_in 5
# fan_out 3
# fan_in 5
# fan_out 3
# std 0.5
# a 0.8660254037844386
# xavier_uniform tensor([[-0.0043, -0.6705, -0.4981, -0.6935,  0.3967],
#         [ 0.3643,  0.2465,  0.6906, -0.2256, -0.7046],
#         [ 0.6660,  0.7381,  0.5887,  0.0423,  0.2840]])
# ****************xavier_normal*****************
# ***********_calculate_fan_in_and_fan_out****************
# dimensions 2
# fan_in 5
# fan_out 3
# fan_in 5
# fan_out 3
# std 0.5
# xavier_normal tensor([[ 0.6554, -0.3533, -0.2101, -0.0362,  0.3919],
#         [ 0.4505, -0.8219,  0.5489,  0.7568, -0.5317],
#         [-0.2396,  0.1093, -0.3372, -0.1136,  0.4452]])

# (Editor: Li Datong) Disclaimer: all content on this site is collected from
# the Internet. The opinions expressed represent only the original author's
# views, not this site's position. If your rights have been unintentionally
# infringed, please contact the site administrator to have the content removed.