加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Python > 正文

flask – ValueError:Tensor’A’必须与Tensor’B’在同一图表

发布时间:2020-12-20 11:57:35 所属栏目:Python 来源:网络整理
导读:我正在使用keras的预训练模型,并且在调用ResNet50时出现错误(权重=’imagenet’). 我在flask服务器中有以下代码: def getVGG16Prediction(img_path): model = VGG16(weights='imagenet',include_top=True) img = image.load_img(img_path,target_size=(224,
我正在使用keras的预训练模型,并且在调用ResNet50时出现错误(权重=’imagenet’).
我在flask服务器中有以下代码:

def getVGG16Prediction(img_path):

    model = VGG16(weights='imagenet',include_top=True)
    img = image.load_img(img_path,target_size=(224,224))
    x = image.img_to_array(img)
    x = np.expand_dims(x,axis=0)
    x = preprocess_input(x)

    pred = model.predict(x)
    return sort(decode_predictions(pred,top=3)[0])


def getResNet50Prediction(img_path):

    model = ResNet50(weights='imagenet') #ERROR HERE
    img = image.load_img(img_path,axis=0)
    x = preprocess_input(x)

    preds = model.predict(x)
    return decode_predictions(preds,top=3)[0]

在main中调用时,它工作正常

if __name__ == "__main__":
    STATIC_PATH = os.getcwd()+"/static"
    print(getVGG16Prediction(STATIC_PATH+"/18.jpg"))
    print(getResNet50Prediction(STATIC_PATH+"/18.jpg"))

但是,当我从烧瓶POST功能调用它时,ValueError会上升:

@app.route("/uploadMultipleImages",methods=["POST"])
def uploadMultipleImages():
    uploaded_files = request.files.getlist("file[]")
    weight = request.form.get("weight")

    for file in uploaded_files:
        path = os.path.join(STATIC_PATH,file.filename)
        file.save(os.path.join(STATIC_PATH,file.filename))
        result = getResNet50Prediction(path)

完整错误如下:

ValueError: Tensor(“cond/pred_id:0”,dtype=bool) must be from the same
graph as Tensor(“batchnorm/add_1:0”,shape=(?,112,64),
dtype=float32)

任何评论或建议都非常感谢.谢谢.

解决方法

您需要打开不同的会话并指定每个会话使用哪个图表,否则Keras会将每个图表替换为默认值.

from tensorflow import Graph,Session,load_model
from Keras import backend as K

加载图表:

graph1 = Graph()
    with graph1.as_default():
        session1 = Session()
        with session1.as_default():
            model = load_model(foo.h5)

graph2 = Graph()
    with graph2.as_default():
        session2 = Session()
        with session2.as_default():
            model2 = load_model(foo2.h5)

预测/使用图表:

K.set_session(session1)
    with graph1.as_default():
        result = model.predict(data)

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读