python – 如何在TensorFlow import_graph_def期间更改输入的维
我的情景:
>定义RNN模型结构并使用具有固定批次大小和序列长度的输入对其进行训练. 问题:上述所有工作都有效,直到我将输入传递给使用批量大小或序列长度不同于训练时使用的原始大小的测试时间图.那时我得到这样的错误: InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [1,5] vs. shape[1] = [2,7] [[Node: import/rnn/while/basic_rnn_cell/basic_rnn_cell_1/concat = ConcatV2[N=2,T=DT_FLOAT,Tidx=DT_INT32,_device="/job:localhost/replica:0/task:0/cpu:0"](import/rnn/while/TensorArrayReadV3,import/rnn/while/Identity_2,import/rnn/while/basic_rnn_cell/basic_rnn_cell_1/concat/axis)]] 为了说明和重现该问题,请考虑以下最小示例. > v1:使用任意批量大小和序列长度创建图表.这很好但不幸的是我必须在训练时使用固定的批量大小和序列长度,并且必须在测试时使用任意批量大小和序列长度,所以我不能使用这种简单的方法. 是否可以通过input_map参数将RNN图更改为tf.import_graph_def,以使输入不再具有固定的批量大小和序列长度? 以下代码适用于TensorFlow 1.1 RC2,可以与TensorFlow 1.0一起使用. import numpy import tensorflow as tf from tensorflow import graph_util as tf_graph_util from tensorflow.contrib import rnn as tfc_rnn def v1(data): with tf.Graph().as_default(): tf.set_random_seed(1) x = tf.placeholder(tf.float32,shape=(None,None,5)) _,s = tf.nn.dynamic_rnn(tfc_rnn.BasicRNNCell(7),x,dtype=tf.float32) with tf.Session() as session: session.run(tf.global_variables_initializer()) print session.run(s,feed_dict={x: data}) def v2a(): with tf.Graph().as_default(): tf.set_random_seed(1) x = tf.placeholder(tf.float32,shape=(2,3,5),name="x") _,dtype=tf.float32) with tf.Session() as session: session.run(tf.global_variables_initializer()) return tf_graph_util.convert_variables_to_constants( session,session.graph_def,[s.op.name]),s.name def v2ba((graph_def,s_name),data): with tf.Graph().as_default(): x,s = tf.import_graph_def(graph_def,return_elements=["x:0",s_name]) with tf.Session() as session: print '2ba',session.run(s,feed_dict={x: data}) def v2bb((graph_def,data): with tf.Graph().as_default(): x = tf.placeholder(tf.float32,5)) [s] = tf.import_graph_def(graph_def,input_map={"x:0": x},return_elements=[s_name]) with tf.Session() as session: print '2bb',feed_dict={x: data}) def v2bc((graph_def,return_elements=[s_name]) with tf.Session() as session: print '2bc',feed_dict={x: data}) def main(): data1 = numpy.random.random_sample((2,5)) data2 = numpy.random.random_sample((1,5)) v1(data1) model = v2a() v2ba(model,data1) v2bb(model,data1) v2bc(model,data2) if __name__ == "__main__": main() 解决方法
这是一个持续一段时间的张量流中的错误:您无法可靠地将具有已定义形状的占位符替换为具有(部分)未定义形状的另一个占位符.
你会发现here提交的相关问题,显然没有引起太多关注. (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |