python – PyTorch,nn.Sequential(),访问nn.Sequential()中特定
发布时间:2020-12-20 12:34:06 所属栏目:Python 来源:网络整理
导读:这应该是一个快速的.当我在PyTorch中使用预定义模块时,我通常可以非常轻松地访问其权重.但是,如果我先将模块包装在nn.Sequential()中,如何访问它们?请看下面的玩具示例 class My_Model_1(nn.Module): def __init__(self,D_in,D_out): super(My_Model_1,self
这应该是一个快速的.当我在PyTorch中使用预定义模块时,我通常可以非常轻松地访问其权重.但是,如果我先将模块包装在nn.Sequential()中,如何访问它们?请看下面的玩具示例
class My_Model_1(nn.Module): def __init__(self,D_in,D_out): super(My_Model_1,self).__init__() self.layer = nn.Linear(D_in,D_out) def forward(self,x): out = self.layer(x) return out class My_Model_2(nn.Module): def __init__(self,D_out): super(My_Model_2,self).__init__() self.layer = nn.Sequential(nn.Linear(D_in,D_out)) def forward(self,x): out = self.layer(x) return out model_1 = My_Model_1(10,10) print(model_1.layer.weight) model_2 = My_Model_2(10,10) # How do I print the weights now? # model_2.layer.0.weight doesn't work. 解决方法
访问权重的一个简单权重是使用模型的state_dict().
这适用于您的情况: for k,v in model_2.state_dict().iteritems(): print("Layer {}".format(k)) print(v) 另一个选择是获取modules()迭代器.如果您事先知道图层的类型,这也应该有效: for layer in model_2.modules(): if isinstance(layer,nn.Linear): print(layer.weight) (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |