加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Python > 正文

获取输入数组和输出数组项以将模型转换为tflite格式

发布时间:2020-12-17 17:36:48 所属栏目:Python 来源:网络整理
导读:PS.请不要将我指向converting Keras model directly to tflite,因为我的.h5文件无法直接转换为.tflite.我以某种方式设法将我的.h5文件转换为.pb 我关注了this Jupyter笔记本,使用Keras进行面部识别.然后,我将模型保存到model.h5文件,然后使用this将其转换为

PS.请不要将我指向converting Keras model directly to tflite,因为我的.h5文件无法直接转换为.tflite.我以某种方式设法将我的.h5文件转换为.pb

我关注了this Jupyter笔记本,使用Keras进行面部识别.然后,我将模型保存到model.h5文件,然后使用this将其转换为冻结图model.pb.

现在我想在Android中使用我的tensorflow文件.为此,我需要Tensorflow Lite,这需要我将模型转换为.tflite格式.

为此,我尝试遵循其官方指南here.如您所见,它需要input_array和output_array数组.如何从我的model.pb文件中获得这些内容的详细信息?

最佳答案
输入数组和输出数组是分别存储输入和输出张量的数组.

They intend to inform the TFLiteConverter about the input and output tensors which will be used at the time of inference.

对于Keras模型,

输入张量是第一层的占位符张量.

input_tensor = model.layers[0].input

输出张量可以与激活函数有关.

output_tensor = model.layers[ LAST_LAYER_INDEX ].output

对于冻结图,

import tensorflow as tf
gf = tf.GraphDef()   
m_file = open('model.pb','rb')
gf.ParseFromString(m_file.read())

我们得到节点的名称,

for n in gf.node:
    print( n.name )

为了得到张量,

tensor = n.op

输入张量可以是占位符张量.输出张量是您使用session.run()运行的张量

对于转换,我们得到

input_array =[ input_tensor ]
output_array = [ output_tensor ]

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读