python – 获取具有可变序列长度的激活时的Tensorflow GRU单元错
我想在一些时间序列数据上运行GRU单元格,以根据最后一层中的激活对它们进行聚类.我对GRU单元实现做了一个小改动
def __call__(self,inputs,state,scope=None): """Gated recurrent unit (GRU) with nunits cells.""" with vs.variable_scope(scope or type(self).__name__): # "GRUCell" with vs.variable_scope("Gates"): # Reset gate and update gate. # We start with bias of 1.0 to not reset and not update. r,u = array_ops.split(1,2,linear([inputs,state],2 * self._num_units,True,1.0)) r,u = sigmoid(r),sigmoid(u) with vs.variable_scope("Candidate"): c = tanh(linear([inputs,r * state],self._num_units,True)) new_h = u * state + (1 - u) * c # store the activations,everything else is the same self.activations = [r,u,c] return new_h,new_h 在此之后,我将以下面的方式连接激活,然后在调用此GRU单元的脚本中返回它们 @property def activations(self): return self._activations @activations.setter def activations(self,activations_array): print "PRINT THIS" concactivations = tf.concat(concat_dim=0,values=activations_array,name='concat_activations') self._activations = tf.reshape(tensor=concactivations,shape=[-1],name='flatten_activations') 我以下面的方式调用GRU单元 outputs,state = rnn.rnn(cell=cell,inputs=x,initial_state=initial_state,sequence_length=s) 其中s是批处理长度数组,其中包含输入批处理的每个元素中的时间戳数. 最后我拿到了 fetched = sess.run(fetches=cell.activations,feed_dict=feed_dict) 执行时我收到以下错误 Traceback(最近一次调用最后一次): return tf_session.TF_Run(session,feed_dict,fetch_list,target_list) 有人可以通过传递可变长度序列来了解如何在最后一步从GRU单元获取激活吗?谢谢. 解决方法
要从最后一步获取激活,您需要将激活作为状态的一部分,由tf.rnn返回.
(编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |