
python – How to load pretrained LSTM model weights in TensorFlow

Published: 2020-12-20 13:12:44 | Category: Python | Source: compiled from the web
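I want to implement an LSTM model in TensorFlow with pretrained weights. The weights may come from Caffe or Torch.
I found that rnn_cell.py contains LSTM cells such as rnn_cell.BasicLSTMCell and rnn_cell.MultiRNNCell, but how do I load pretrained weights into these LSTM cells?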

Solution

Here is a solution for loading a pretrained Caffe model. See the full code here, which is referenced in the discussion in this thread.

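import caffe
import tensorflow as tf

# prototxt and caffemodel are paths to the network definition and the trained weights
net_caffe = caffe.Net(prototxt, caffemodel, caffe.TEST)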
caffe_layers = {}

for i,layer in enumerate(net_caffe.layers):
    layer_name = net_caffe._layer_names[i]
    caffe_layers[layer_name] = layer

def caffe_weights(layer_name):
    layer = caffe_layers[layer_name]
    return layer.blobs[0].data

def caffe_bias(layer_name):
    layer = caffe_layers[layer_name]
    return layer.blobs[1].data

# TensorFlow filters are [filter_height, filter_width, in_channels, out_channels], i.e. Caffe axes 2-3-1-0
# Caffe filters are [out_channels, in_channels, filter_height, filter_width], i.e. axes 0-1-2-3
def caffe2tf_filter(name):
    f = caffe_weights(name)
    return f.transpose((2,3,1,0))

class ModelFromCaffe():
    def get_conv_filter(self,name):
        w = caffe2tf_filter(name)
        return tf.constant(w,dtype=tf.float32,name="filter")

    def get_bias(self,name):
        b = caffe_bias(name)
        return tf.constant(b,name="bias")

    def get_fc_weight(self,name):
        cw = caffe_weights(name)
        if name == "fc6":
            assert cw.shape == (4096,25088)
            cw = cw.reshape((4096,512,7,7)) 
            cw = cw.transpose((2, 3, 1, 0))  # [out, C, H, W] -> [H, W, C, out]
            cw = cw.reshape(25088,4096)
        else:
            cw = cw.transpose((1,0))

        return tf.constant(cw,name="weight")

images = tf.placeholder("float", [None, 224, 224, 3], name="images")
m = ModelFromCaffe()

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  # cat is a 224x224 RGB image array loaded elsewhere
  batch = cat.reshape((1, 224, 224, 3))
  out = sess.run([m.prob,m.relu1_1,m.pool5,m.fc6],feed_dict={ images: batch })
...
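
The two layout conversions in the listing above (the conv filter transpose and the fc6 reshape) can be checked without Caffe or TensorFlow installed. The following is a minimal NumPy sketch under stated assumptions: the helper names `oihw_to_hwio` and `caffe_fc_after_conv_to_tf` are hypothetical, and the small shapes stand in for the real 512x7x7 feature map. It confirms that a Caffe fully connected layer applied to a channel-first feature map gives the same result as the converted matrix applied to the channel-last feature map.

```python
import numpy as np

def oihw_to_hwio(w):
    """Caffe conv filter [out, in, h, w] -> TensorFlow filter [h, w, in, out]."""
    return w.transpose((2, 3, 1, 0))

def caffe_fc_after_conv_to_tf(w, channels, height, width):
    """Reorder a Caffe FC matrix [out, C*H*W] (input flattened channel-first)
    into a matrix [H*W*C, out] for an input flattened channel-last."""
    out_dim = w.shape[0]
    w = w.reshape((out_dim, channels, height, width))
    w = w.transpose((2, 3, 1, 0))  # -> [H, W, C, out]
    return w.reshape((height * width * channels, out_dim))

# Hypothetical small shapes, chosen only to keep the check fast.
rng = np.random.default_rng(0)
conv = rng.standard_normal((8, 3, 5, 5)).astype(np.float32)
tf_conv = oihw_to_hwio(conv)
assert tf_conv.shape == (5, 5, 3, 8)
assert tf_conv[1, 2, 0, 7] == conv[7, 0, 1, 2]

C, H, W, OUT = 4, 2, 3, 6
fc = rng.standard_normal((OUT, C * H * W)).astype(np.float32)
tf_fc = caffe_fc_after_conv_to_tf(fc, C, H, W)

feat_chw = rng.standard_normal((C, H, W)).astype(np.float32)  # Caffe layout
feat_hwc = feat_chw.transpose((1, 2, 0))                      # TensorFlow layout
caffe_out = fc @ feat_chw.reshape(-1)
tf_out = feat_hwc.reshape(-1) @ tf_fc
assert np.allclose(caffe_out, tf_out, atol=1e-4)
```

Running the same kind of equivalence check on the real weights, one layer at a time, is a cheap way to catch a wrong axis order before building the full graph.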

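The answer above covers convolutional and fully connected layers only. For the LSTM cells the question asks about, the same idea applies, but the gate blocks usually have to be reordered as well. The following is a rough NumPy sketch, not a definitive recipe. It assumes that the source framework stores an input-to-hidden matrix of shape [4*H, D], a hidden-to-hidden matrix of shape [4*H, H], and a bias of shape [4*H], with the four gates stacked in some known order. It also assumes that BasicLSTMCell uses one fused matrix of shape [D+H, 4*H] with gates ordered input, candidate, forget, output. The function name `pack_lstm_kernel` and the source order used in the demo are hypothetical, so check both layouts against the code you actually export from and import into.

```python
import numpy as np

# Assumed gate order of the fused BasicLSTMCell matrix:
# input gate, candidate ("j"), forget gate, output gate.
TF_ORDER = ("i", "j", "f", "o")

def pack_lstm_kernel(w_x, w_h, b, src_order, dst_order=TF_ORDER):
    """w_x: [4H, D], w_h: [4H, H], b: [4H], gates stacked on axis 0 in src_order.
    Returns a fused kernel [D+H, 4H] and bias [4H] with gates in dst_order."""
    num_units = w_h.shape[1]

    def reorder(a):
        blocks = {g: a[k * num_units:(k + 1) * num_units]
                  for k, g in enumerate(src_order)}
        return np.concatenate([blocks[g] for g in dst_order], axis=0)

    kernel = np.concatenate([reorder(w_x), reorder(w_h)], axis=1).T
    return kernel, reorder(b)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def step_source(x, h, c, w_x, w_h, b, order):
    """One LSTM step using the separate source-layout matrices."""
    z = w_x @ x + w_h @ h + b
    n = h.shape[0]
    g = {name: z[k * n:(k + 1) * n] for k, name in enumerate(order)}
    c_new = sigmoid(g["f"]) * c + sigmoid(g["i"]) * np.tanh(g["j"])
    return sigmoid(g["o"]) * np.tanh(c_new), c_new

def step_fused(x, h, c, kernel, bias):
    """One LSTM step using the fused layout, with no extra forget bias."""
    z = np.concatenate([x, h]) @ kernel + bias
    i, j, f, o = np.split(z, 4)
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(j)
    return sigmoid(o) * np.tanh(c_new), c_new

rng = np.random.default_rng(0)
D, H = 5, 3
SRC_ORDER = ("i", "f", "o", "j")  # hypothetical order used by the exporter
w_x = rng.standard_normal((4 * H, D))
w_h = rng.standard_normal((4 * H, H))
b = rng.standard_normal(4 * H)
x, h, c = rng.standard_normal(D), rng.standard_normal(H), rng.standard_normal(H)

kernel, bias = pack_lstm_kernel(w_x, w_h, b, SRC_ORDER)
h_src, c_src = step_source(x, h, c, w_x, w_h, b, SRC_ORDER)
h_tf, c_tf = step_fused(x, h, c, kernel, bias)
assert np.allclose(h_src, h_tf) and np.allclose(c_src, c_tf)
```

The resulting arrays could then be assigned to the cell's weight and bias variables. Their names differ between TensorFlow versions, so it is safer to look them up (for example via tf.trainable_variables()) than to hard-code them. BasicLSTMCell also adds a forget_bias term (1.0 by default) to the forget gate, so imported weights will likely only reproduce the source model if the cell is built with forget_bias=0.0 or the imported bias is adjusted accordingly.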
(Editor: 李大同)
