加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Python > 正文

python – Tensorflow条件抛出值错误

发布时间:2020-12-20 11:51:33 所属栏目:Python 来源:网络整理
导读:我试图使用张量流的条件,我收到错误: ValueError: Shapes (1,) and () are not compatible 下面是我使用的抛出错误的代码. 它说错误是有条件的 import tensorflow as tfimport numpy as npX = tf.constant([1,0])Y = tf.constant([0,1])BOTH = tf.constant(
我试图使用张量流的条件,我收到错误:

ValueError: Shapes (1,) and () are not compatible

下面是我使用的抛出错误的代码.
它说错误是有条件的

import tensorflow as tf
import numpy as np

X = tf.constant([1,0])
Y = tf.constant([0,1])
BOTH = tf.constant([1,1])
WORKING = tf.constant(1)

def create_mult_func(tf,amount,list):
    def f1():
        return tf.scalar_mul(amount,list)
    return f1

def create_no_op_func(tensor):
    def f1():
        return tensor
    return f1

def stretch(tf,points,dim,amount):
    """points is a 2 by ??? tensor,dim is a 1 by 2 tensor,amount is tensor scalor"""
    x_list,y_list = tf.split(0,2,points)
    x_stretch,y_stretch = tf.split(1,dim)
    is_stretch_X = tf.equal(x_stretch,WORKING,name="is_stretch_x")
    is_stretch_Y = tf.equal(y_stretch,name="is_stretch_Y")
    x_list_stretched = tf.cond(is_stretch_X,create_mult_func(tf,x_list),create_no_op_func(x_list))
    y_list_stretched = tf.cond(is_stretch_Y,y_list),create_no_op_func(y_list))
    return tf.concat(1,[x_list_stretched,y_list_stretched])

example_points = np.array([[1,1],[2,2],[3,3]],dtype=np.float32)
example_point_list = tf.placeholder(tf.float32)

result = stretch(tf,example_point_list,X,1)
sess = tf.Session()

with tf.Session() as sess:
    result = sess.run(result,feed_dict={example_point_list: example_points})
    print(result)

堆栈跟踪:

File "/path/test2.py",line 36,in <module>
    result = stretch(tf,1)
  File "/path/test2.py",line 28,in stretch
    create_mult_func(tf,create_no_op_func(x_list))
  File "/path/tensorflow/python/ops/control_flow_ops.py",line 1142,in cond
    p_2,p_1 = switch(pred,pred)
  File "/path/tensorflow/python/ops/control_flow_ops.py",line 203,in switch
    return gen_control_flow_ops._switch(data,pred,name=name)
  File "/path/tensorflow/python/ops/gen_control_flow_ops.py",line 297,in _switch
    return _op_def_lib.apply_op("Switch",data=data,pred=pred,name=name)
  File "/path/tensorflow/python/ops/op_def_library.py",line 655,in apply_op
    op_def=op_def)
  File "/path/tensorflow/python/framework/ops.py",line 2156,in create_op
    set_shapes_for_outputs(ret)
  File "/path/tensorflow/python/framework/ops.py",line 1612,in set_shapes_for_outputs
    shapes = shape_func(op)
  File "/path/tensorflow/python/ops/control_flow_ops.py",line 2032,in _SwitchShape
    unused_pred_shape = op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
  File "/path/tensorflow/python/framework/tensor_shape.py",line 554,in merge_with
    (self,other))
ValueError: Shapes (1,) and () are not compatible

我已经尝试将WORKING更改为数组而不是标量.

我相信问题是tf.equal返回一个int32而不是它应该根据文档返回的bool

解决方法

问题在于tf.cond的第一个参数.从文档 here,关于tf.cond的第一个参数的类型:

pred: A scalar determining whether to return the result of fn1 or fn2.

请注意,它必须是标量.你正在使用比较张量和张量的结果,它给你一个(1,)张量,而不是标量.您可以使用tf.reshape运算符将其转换为标量,如下所示:

t = tf.equal(x_stretch,name="is_stretch_x")
x_list_stretched = tf.cond(tf.reshape(t,[]),create_no_op_func(x_list))

完整的工作程序:

import tensorflow as tf
import numpy as np

X = tf.constant([1,y_stretch = tf.split(0,name="is_stretch_Y")
    x_list_stretched = tf.cond(tf.reshape(is_stretch_X,create_no_op_func(x_list))
    y_list_stretched = tf.cond(tf.reshape(is_stretch_Y,create_no_op_func(y_list))
    return tf.pack([x_list_stretched,2]],feed_dict={example_point_list: example_points})
    print(result)

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读