python – 在Tensorflow Object Detection API中打印类名和分数
发布时间:2020-12-20 12:04:29 所属栏目:Python 来源:网络整理
导读:我正在使用Tensorflow对象检测API一切正常但我想打印一个dict或数组,其格式如下{Object name,Score}或类似的东西,我需要的是对象名称和分数. 我尝试使用以下代码: with detection_graph.as_default(): with tf.Session(graph=detection_graph) as sess: # D
我正在使用Tensorflow对象检测API一切正常但我想打印一个dict或数组,其格式如下{Object name,Score}或类似的东西,我需要的是对象名称和分数.
我尝试使用以下代码: with detection_graph.as_default(): with tf.Session(graph=detection_graph) as sess: # Definite input and output Tensors for detection_graph image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') # Each box represents a part of the image where a particular object was detected. detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0') # Each score represent how level of confidence for each of the objects. # Score is shown on the result image,together with the class label. detection_scores = detection_graph.get_tensor_by_name('detection_scores:0') detection_classes = detection_graph.get_tensor_by_name('detection_classes:0') num_detections = detection_graph.get_tensor_by_name('num_detections:0') for image_path in TEST_IMAGE_PATHS: image = Image.open(image_path) # the array based representation of the image will be used later in order to prepare the # result image with boxes and labels on it. image_np = load_image_into_numpy_array(image) # Expand dimensions since the model expects images to have shape: [1,None,3] image_np_expanded = np.expand_dims(image_np,axis=0) # Actual detection. (boxes,scores,classes,num) = sess.run( [detection_boxes,detection_scores,detection_classes,num_detections],feed_dict={image_tensor: image_np_expanded}) print ([category_index.get(value) for index,value in enumerate(classes[0]) if scores[0,index] > 0.5]) threshold = 0.5 # in order to get higher percentages you need to lower this number; usually at 0.01 you get 100% predicted objects print(len(np.where(scores[0] > threshold)[0])/num_detections[0]) 这部分是有效的 print ([category_index.get(value) for index,index] > 0.5]) 它正在打印[{‘name’:’computer’,’id’:1}]他们是否能以任何方式将该对象的分数添加到字典? 我在他们使用的Stackoverflow上看到了另一个问题: threshold = 0.5 # in order to get higher percentages you need to lower this number; usually at 0.01 you get 100% predicted objects print(len(np.where(scores[0] > threshold)[0])/num_detections[0]) 这给了我Tensor(“truediv:0”,dtype = float32)但是如果它工作的话虽然它还不够,因为我没有对象的名字. 谢谢 解决方法
所以这是适合我的解决方案. (如果您仍在寻找解决方案,那就是)
# The following code replaces the 'print ([category_index...' statement objects = [] for index,value in enumerate(classes[0]): object_dict = {} if scores[0,index] > threshold: object_dict[(category_index.get(value)).get('name').encode('utf8')] = scores[0,index] objects.append(object_dict) print objects (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |
相关内容
- 利用Python获取赶集网招聘信息前篇
- python – 如何使用Marshmallow序列化MongoDB ObjectId?
- python open cv loading image
- python – 针对.pyc文件的编译器?
- 关于exponents和int的Python问题
- python – cartopy set_xlabel set_ylabel(不是ticklabels)
- python – TypeError:Reduce()没有初始值的空序列
- python读取文件并写入内容到文件
- python基础3之文件操作、字符编码解码、函数介绍
- 明天就是七夕了,用Python做了个可能会被女朋友打死的礼物!
推荐文章
站长推荐
热点阅读