加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Python > 正文

python – Tensorflow和cifar 10,测试单个图像

发布时间:2020-12-20 13:14:43 所属栏目:Python 来源:网络整理
导读:我试图用tensorflow的cifar-10预测单个图像的类. 我找到了这个代码,但它失败了这个错误: 分配要求两个张量的形状匹配. lhs shape = [18,384] rhs shape = [2304,384] 我理解这是因为批次的大小只有1.(使用expand_dims我创建一个假批次.) 但我不知道如何解决
我试图用tensorflow的cifar-10预测单个图像的类.

我找到了这个代码,但它失败了这个错误:

分配要求两个张量的形状匹配. lhs shape = [18,384] rhs shape = [2304,384]
我理解这是因为批次的大小只有1.(使用expand_dims我创建一个假批次.)

但我不知道如何解决这个问题?

我到处搜索但没有解决方案..
提前致谢!

from PIL import Image
import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10
width = 24
height = 24

categories =  ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck" ]

filename = "path/to/jpg" # absolute path to input image
im = Image.open(filename)
im.save(filename,format='JPEG',subsampling=0,quality=100)
input_img = tf.image.decode_jpeg(tf.read_file(filename),channels=3)
tf_cast = tf.cast(input_img,tf.float32)
float_image = tf.image.resize_image_with_crop_or_pad(tf_cast,height,width)
images = tf.expand_dims(float_image,0)
logits = cifar10.inference(images)
_,top_k_pred = tf.nn.top_k(logits,k=5)
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state('/tmp/cifar10_train')
if ckpt and ckpt.model_checkpoint_path:
    print("ckpt.model_checkpoint_path ",ckpt.model_checkpoint_path)
    saver.restore(sess,ckpt.model_checkpoint_path)
else:
    print('No checkpoint file found')
    exit(0)
sess.run(init_op)
_,top_indices = sess.run([_,top_k_pred])
for key,value in enumerate(top_indices[0]):
    print (categories[value] + "," + str(_[0][key]))

编辑

我尝试放置一个占位符,在第一个形状中使用None,但是我收到了这个错误:
必须完全定义新变量(local3 / weights)的形状,而不是(?,384).

现在我真的迷路了……
这是新代码:

from PIL import Image
import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10
import itertools
width = 24
height = 24

categories = [ "airplane","truck" ]

filename = "toto.jpg" # absolute path to input image
im = Image.open(filename)
im.save(filename,quality=100)
x = tf.placeholder(tf.float32,[None,24,3])
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
 # Restore variables from training checkpoint.
    input_img = tf.image.decode_jpeg(tf.read_file(filename),channels=3)
    tf_cast = tf.cast(input_img,tf.float32)
    float_image = tf.image.resize_image_with_crop_or_pad(tf_cast,width)
    images = tf.expand_dims(float_image,0)
    i = images.eval()
    print (i)
    sess.run(init_op,feed_dict={x: i})
    logits = cifar10.inference(x)
    _,k=5)
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)
    ckpt = tf.train.get_checkpoint_state('/tmp/cifar10_train')
    if ckpt and ckpt.model_checkpoint_path:
        print("ckpt.model_checkpoint_path ",ckpt.model_checkpoint_path)
        saver.restore(sess,ckpt.model_checkpoint_path)
    else:
        print('No checkpoint file found')
        exit(0)
    _,top_k_pred])
    for key,value in enumerate(top_indices[0]):
        print (categories[value] + "," + str(_[0][key]))

解决方法

我认为这是因为tf.Variable或tf.get_variable获取的变量必须具有完整定义的形状.您可以检查代码并提供完整定义的形状.

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读