加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Python > 正文

python-2.7 – 如何在Tensorflow中使用CheckpointReader恢复变量

发布时间:2020-12-20 11:53:50 所属栏目:Python 来源:网络整理
导读:我正在尝试从检查点文件中恢复一些变量,如果相同的变量名在当前模型中. 我发现有一些方法,如 Tensorfow Github 所以我想要做的是使用has_tensor(“variable.name”)检查检查点文件中的变量名,如下所示, ... reader = tf.train.NewCheckpointReader(ckpt_path
我正在尝试从检查点文件中恢复一些变量,如果相同的变量名在当前模型中.
我发现有一些方法,如 Tensorfow Github

所以我想要做的是使用has_tensor(“variable.name”)检查检查点文件中的变量名,如下所示,

...    
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
    print v.name
    if reader.has_tensor(v.name):
        print 'has tensor'
...

但我发现v.name返回变量名和冒号.例如,我有变量名W_o和b_o然后v.name返回W_o:0,b_o:0.

但是reader.has_tensor()需要不带冒号的名称和数字为W_o,b_o.

我的问题是:如何删除变量名末尾的冒号和数字以读取变量?
有没有更好的方法来恢复这些变量?

解决方法

您可以使用 string.split()来获取张量名称:

...    
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
    tensor_name = v.name.split(':')[0]
    print tensor_name
    if reader.has_tensor(tensor_name):
        print 'has tensor'
...

接下来,让我用一个例子来说明如何从.cpkt文件中恢复每个可能的变量.首先,让我们在tmp.ckpt中保存v2和v3:

import tensorflow as tf

v1 = tf.Variable(tf.ones([1]),name='v1')
v2 = tf.Variable(2 * tf.ones([1]),name='v2')
v3 = tf.Variable(3 * tf.ones([1]),name='v3')

saver = tf.train.Saver({'v2': v2,'v3': v3})

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    saver.save(sess,'tmp.ckpt')

这就是我将如何恢复tmp.ckpt中显示的每个变量(属于新图形):

with tf.Graph().as_default():
    assert len(tf.trainable_variables()) == 0
    v1 = tf.Variable(tf.zeros([1]),name='v1')
    v2 = tf.Variable(tf.zeros([1]),name='v2')

    reader = tf.train.NewCheckpointReader('tmp.ckpt')
    restore_dict = dict()
    for v in tf.trainable_variables():
        tensor_name = v.name.split(':')[0]
        if reader.has_tensor(tensor_name):
            print('has tensor ',tensor_name)
            restore_dict[tensor_name] = v

    saver = tf.train.Saver(restore_dict)
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        saver.restore(sess,'tmp.ckpt')
        print(sess.run([v1,v2])) # prints [array([ 0.],dtype=float32),array([ 2.],dtype=float32)]

此外,您可能希望确保形状和dtypes匹配.

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读