python-2.7 – 如何在Tensorflow中使用CheckpointReader恢复变量
发布时间:2020-12-20 11:53:50 所属栏目:Python 来源:网络整理
导读:我正在尝试从检查点文件中恢复一些变量,如果相同的变量名在当前模型中. 我发现有一些方法,如 Tensorfow Github 所以我想要做的是使用has_tensor(“variable.name”)检查检查点文件中的变量名,如下所示, ... reader = tf.train.NewCheckpointReader(ckpt_path
我正在尝试从检查点文件中恢复一些变量,如果相同的变量名在当前模型中.
我发现有一些方法,如 Tensorfow Github 所以我想要做的是使用has_tensor(“variable.name”)检查检查点文件中的变量名,如下所示, ... reader = tf.train.NewCheckpointReader(ckpt_path) for v in tf.trainable_variables(): print v.name if reader.has_tensor(v.name): print 'has tensor' ... 但我发现v.name返回变量名和冒号.例如,我有变量名W_o和b_o然后v.name返回W_o:0,b_o:0. 但是reader.has_tensor()需要不带冒号的名称和数字为W_o,b_o. 我的问题是:如何删除变量名末尾的冒号和数字以读取变量? 解决方法
您可以使用
string.split()来获取张量名称:
... reader = tf.train.NewCheckpointReader(ckpt_path) for v in tf.trainable_variables(): tensor_name = v.name.split(':')[0] print tensor_name if reader.has_tensor(tensor_name): print 'has tensor' ... 接下来,让我用一个例子来说明如何从.cpkt文件中恢复每个可能的变量.首先,让我们在tmp.ckpt中保存v2和v3: import tensorflow as tf v1 = tf.Variable(tf.ones([1]),name='v1') v2 = tf.Variable(2 * tf.ones([1]),name='v2') v3 = tf.Variable(3 * tf.ones([1]),name='v3') saver = tf.train.Saver({'v2': v2,'v3': v3}) with tf.Session() as sess: sess.run(tf.initialize_all_variables()) saver.save(sess,'tmp.ckpt') 这就是我将如何恢复tmp.ckpt中显示的每个变量(属于新图形): with tf.Graph().as_default(): assert len(tf.trainable_variables()) == 0 v1 = tf.Variable(tf.zeros([1]),name='v1') v2 = tf.Variable(tf.zeros([1]),name='v2') reader = tf.train.NewCheckpointReader('tmp.ckpt') restore_dict = dict() for v in tf.trainable_variables(): tensor_name = v.name.split(':')[0] if reader.has_tensor(tensor_name): print('has tensor ',tensor_name) restore_dict[tensor_name] = v saver = tf.train.Saver(restore_dict) with tf.Session() as sess: sess.run(tf.initialize_all_variables()) saver.restore(sess,'tmp.ckpt') print(sess.run([v1,v2])) # prints [array([ 0.],dtype=float32),array([ 2.],dtype=float32)] 此外,您可能希望确保形状和dtypes匹配. (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |