python – tensorflow:简单LSTM网络的共享变量错误
发布时间:2020-12-20 12:10:32 所属栏目:Python 来源:网络整理
导读:我正在尝试构建一个最简单的LSTM网络.只是希望它预测序列np_input_data中的下一个值. import tensorflow as tffrom tensorflow.python.ops import rnn_cellimport numpy as npnum_steps = 3num_units = 1np_input_data = [np.array([[1.],[2.]]),np.array([[
我正在尝试构建一个最简单的LSTM网络.只是希望它预测序列np_input_data中的下一个值.
import tensorflow as tf from tensorflow.python.ops import rnn_cell import numpy as np num_steps = 3 num_units = 1 np_input_data = [np.array([[1.],[2.]]),np.array([[2.],[3.]]),np.array([[3.],[4.]])] batch_size = 2 graph = tf.Graph() with graph.as_default(): tf_inputs = [tf.placeholder(tf.float32,[batch_size,1]) for _ in range(num_steps)] lstm = rnn_cell.BasicLSTMCell(num_units) initial_state = state = tf.zeros([batch_size,lstm.state_size]) loss = 0 for i in range(num_steps-1): output,state = lstm(tf_inputs[i],state) loss += tf.reduce_mean(tf.square(output - tf_inputs[i+1])) with tf.Session(graph=graph) as session: tf.initialize_all_variables().run() feed_dict={tf_inputs[i]: np_input_data[i] for i in range(len(np_input_data))} loss = session.run(loss,feed_dict=feed_dict) print(loss) 口译员返回: ValueError: Variable BasicLSTMCell/Linear/Matrix already exists,disallowed. Did you mean to set reuse=True in VarScope? Originally defined at: output,state) 我做错了什么? 解决方法
在这里打电话给lstm:
for i in range(num_steps-1): output,state) 将尝试每次迭代创建具有相同名称的变量,除非您另有说明.您可以使用tf.variable_scope执行此操作 with tf.variable_scope("myrnn") as scope: for i in range(num_steps-1): if i > 0: scope.reuse_variables() output,state) 第一次迭代创建表示LSTM参数的变量,每次后续迭代(在调用reuse_variables之后)将只按名称在范围内查找它们. (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |