多线程 – Keras Tensorflow – 从多个线程预测时的异常
我正在使用keras 2.0.8和tensorflow 1.3.0后端.
我在类init中加载一个模型,然后用它来预测多线程. import tensorflow as tf from keras import backend as K from keras.models import load_model class CNN: def __init__(self,model_path): self.cnn_model = load_model(model_path) self.session = K.get_session() self.graph = tf.get_default_graph() def query_cnn(self,data): X = self.preproccesing(data) with self.session.as_default(): with self.graph.as_default(): return self.cnn_model.predict(X) 我初始化CNN一次,query_cnn方法从多个线程发生. 我在日志中得到的例外是: File "/home/*/Similarity/CNN.py",line 43,in query_cnn return self.cnn_model.predict(X) File "/usr/local/lib/python3.5/dist-packages/keras/models.py",line 913,in predict return self.model.predict(x,batch_size=batch_size,verbose=verbose) File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py",line 1713,in predict verbose=verbose,steps=steps) File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py",line 1269,in _predict_loop batch_outs = f(ins_batch) File "/usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py",line 2273,in __call__ **self.session_kwargs) File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",line 895,in run run_metadata_ptr) File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",line 1124,in _run feed_dict_tensor,options,run_metadata) File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",line 1321,in _do_run options,line 1340,in _do_call raise type(e)(node_def,op,message) tensorflow.python.framework.errors_impl.NotFoundError: PruneForTargets: Some target nodes not found: group_deps 代码在大多数情况下工作正常,它可能是多线程的一些问题. 我该如何解决? 解决方法
确保在创建其他线程之前完成图形创建.
在图表上调用finalize()可以帮助您. def __init__(self,model_path): self.cnn_model = load_model(model_path) self.session = K.get_session() self.graph = tf.get_default_graph() self.graph.finalize() 更新1:finalize()将使您的图形为只读,以便可以安全地在多个线程中使用.作为副作用,它将帮助您找到无意的行为,有时还会发现内存泄漏,因为当您尝试修改图形时它会引发异常. 想象一下,你有一个线程可以做一个例如输入的热编码. (坏的例子:) def preprocessing(self,data): one_hot_data = tf.one_hot(data,depth=self.num_classes) return self.session.run(one_hot_data) 如果在图表中打印对象数量,您会发现它会随着时间的推移而增加 # amount of nodes in tf graph print(len(list(tf.get_default_graph().as_graph_def().node))) 但是,如果您首先定义图形不是这种情况(略微更好的代码): def preprocessing(self,data): # run pre-created operation with self.input as placeholder return self.session.run(self.one_hot_data,feed_dict={self.input: data}) 更新2:根据此thread,您需要在执行多线程之前在keras模型上调用model._make_predict_function().
更新的代码: def __init__(self,model_path): self.cnn_model = load_model(model_path) self.cnn_model._make_predict_function() # have to initialize before threading self.session = K.get_session() self.graph = tf.get_default_graph() self.graph.finalize() # make graph read-only 更新3:我做了一个预热概念的证明,因为_make_predict_function()似乎没有按预期工作. import tensorflow as tf from keras.layers import * from keras.models import * model = Sequential() model.add(Dense(256,input_shape=(2,))) model.add(Dense(1,activation='softmax')) model.compile(loss='mean_squared_error',optimizer='adam') model.save("dummymodel") 然后在另一个脚本中我加载了该模型并使其在多个线程上运行 import tensorflow as tf from keras import backend as K from keras.models import load_model import threading as t import numpy as np K.clear_session() class CNN: def __init__(self,model_path): self.cnn_model = load_model(model_path) self.cnn_model.predict(np.array([[0,0]])) # warmup self.session = K.get_session() self.graph = tf.get_default_graph() self.graph.finalize() # finalize def preproccesing(self,data): # dummy return data def query_cnn(self,data): X = self.preproccesing(data) with self.session.as_default(): with self.graph.as_default(): prediction = self.cnn_model.predict(X) print(prediction) return prediction cnn = CNN("dummymodel") th = t.Thread(target=cnn.query_cnn,kwargs={"data": np.random.random((500,2))}) th2 = t.Thread(target=cnn.query_cnn,2))}) th3 = t.Thread(target=cnn.query_cnn,2))}) th4 = t.Thread(target=cnn.query_cnn,2))}) th5 = t.Thread(target=cnn.query_cnn,2))}) th.start() th2.start() th3.start() th4.start() th5.start() th2.join() th.join() th3.join() th5.join() th4.join() 评论预热和最终确定的线条我能够重现你的第一个问题 (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |
- java – 提前创建变量以在循环中使用(?)
- redis与spring整合使用的步骤实例教程
- RxJava – 何时使用带有create方法的Observable
- java – 如何在构建版本不是-SNAPSHOT时启用maven配置文件?
- 使用 RxJava 进行嵌套串行网络请求的一种方法
- Java动态规划之硬币找零问题实现代码
- java – 覆盖使用重写的toString()的toString()
- intellij idea 怎么全局搜索--转
- Java Collections类:sort()升序排序、reverse()降序排序、
- java – 为什么String.indexOf不使用异常,但是当没有找到子