python – Tensorflow条件抛出值错误
发布时间:2020-12-20 11:51:33 所属栏目:Python 来源:网络整理
导读:我试图使用张量流的条件,我收到错误: ValueError: Shapes (1,) and () are not compatible 下面是我使用的抛出错误的代码. 它说错误是有条件的 import tensorflow as tfimport numpy as npX = tf.constant([1,0])Y = tf.constant([0,1])BOTH = tf.constant(
我试图使用张量流的条件,我收到错误:
ValueError: Shapes (1,) and () are not compatible 下面是我使用的抛出错误的代码. import tensorflow as tf import numpy as np X = tf.constant([1,0]) Y = tf.constant([0,1]) BOTH = tf.constant([1,1]) WORKING = tf.constant(1) def create_mult_func(tf,amount,list): def f1(): return tf.scalar_mul(amount,list) return f1 def create_no_op_func(tensor): def f1(): return tensor return f1 def stretch(tf,points,dim,amount): """points is a 2 by ??? tensor,dim is a 1 by 2 tensor,amount is tensor scalor""" x_list,y_list = tf.split(0,2,points) x_stretch,y_stretch = tf.split(1,dim) is_stretch_X = tf.equal(x_stretch,WORKING,name="is_stretch_x") is_stretch_Y = tf.equal(y_stretch,name="is_stretch_Y") x_list_stretched = tf.cond(is_stretch_X,create_mult_func(tf,x_list),create_no_op_func(x_list)) y_list_stretched = tf.cond(is_stretch_Y,y_list),create_no_op_func(y_list)) return tf.concat(1,[x_list_stretched,y_list_stretched]) example_points = np.array([[1,1],[2,2],[3,3]],dtype=np.float32) example_point_list = tf.placeholder(tf.float32) result = stretch(tf,example_point_list,X,1) sess = tf.Session() with tf.Session() as sess: result = sess.run(result,feed_dict={example_point_list: example_points}) print(result) 堆栈跟踪: File "/path/test2.py",line 36,in <module> result = stretch(tf,1) File "/path/test2.py",line 28,in stretch create_mult_func(tf,create_no_op_func(x_list)) File "/path/tensorflow/python/ops/control_flow_ops.py",line 1142,in cond p_2,p_1 = switch(pred,pred) File "/path/tensorflow/python/ops/control_flow_ops.py",line 203,in switch return gen_control_flow_ops._switch(data,pred,name=name) File "/path/tensorflow/python/ops/gen_control_flow_ops.py",line 297,in _switch return _op_def_lib.apply_op("Switch",data=data,pred=pred,name=name) File "/path/tensorflow/python/ops/op_def_library.py",line 655,in apply_op op_def=op_def) File "/path/tensorflow/python/framework/ops.py",line 2156,in create_op set_shapes_for_outputs(ret) File "/path/tensorflow/python/framework/ops.py",line 1612,in set_shapes_for_outputs shapes = shape_func(op) File "/path/tensorflow/python/ops/control_flow_ops.py",line 2032,in _SwitchShape unused_pred_shape = op.inputs[1].get_shape().merge_with(tensor_shape.scalar()) File "/path/tensorflow/python/framework/tensor_shape.py",line 554,in merge_with (self,other)) ValueError: Shapes (1,) and () are not compatible 我已经尝试将WORKING更改为数组而不是标量. 我相信问题是tf.equal返回一个int32而不是它应该根据文档返回的bool 解决方法
问题在于tf.cond的第一个参数.从文档
here,关于tf.cond的第一个参数的类型:
pred: A scalar determining whether to return the result of fn1 or fn2. 请注意,它必须是标量.你正在使用比较张量和张量的结果,它给你一个(1,)张量,而不是标量.您可以使用 t = tf.equal(x_stretch,name="is_stretch_x") x_list_stretched = tf.cond(tf.reshape(t,[]),create_no_op_func(x_list)) 完整的工作程序: import tensorflow as tf import numpy as np X = tf.constant([1,y_stretch = tf.split(0,name="is_stretch_Y") x_list_stretched = tf.cond(tf.reshape(is_stretch_X,create_no_op_func(x_list)) y_list_stretched = tf.cond(tf.reshape(is_stretch_Y,create_no_op_func(y_list)) return tf.pack([x_list_stretched,2]],feed_dict={example_point_list: example_points}) print(result) (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |