加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Python > 正文

python – 获取具有可变序列长度的激活时的Tensorflow GRU单元错

发布时间:2020-12-20 13:11:55 所属栏目:Python 来源:网络整理
导读:我想在一些时间序列数据上运行GRU单元格,以根据最后一层中的激活对它们进行聚类.我对GRU单元实现做了一个小改动 def __call__(self,inputs,state,scope=None):"""Gated recurrent unit (GRU) with nunits cells."""with vs.variable_scope(scope or type(sel
我想在一些时间序列数据上运行GRU单元格,以根据最后一层中的激活对它们进行聚类.我对GRU单元实现做了一个小改动

def __call__(self,inputs,state,scope=None):
"""Gated recurrent unit (GRU) with nunits cells."""
with vs.variable_scope(scope or type(self).__name__):  # "GRUCell"
  with vs.variable_scope("Gates"):  # Reset gate and update gate.
    # We start with bias of 1.0 to not reset and not update.
    r,u = array_ops.split(1,2,linear([inputs,state],2 * self._num_units,True,1.0))
    r,u = sigmoid(r),sigmoid(u)
  with vs.variable_scope("Candidate"):
    c = tanh(linear([inputs,r * state],self._num_units,True))
  new_h = u * state + (1 - u) * c

  # store the activations,everything else is the same
  self.activations = [r,u,c]
return new_h,new_h

在此之后,我将以下面的方式连接激活,然后在调用此GRU单元的脚本中返回它们

@property
def activations(self):
    return self._activations


@activations.setter
def activations(self,activations_array):
    print "PRINT THIS"         
    concactivations = tf.concat(concat_dim=0,values=activations_array,name='concat_activations')
    self._activations = tf.reshape(tensor=concactivations,shape=[-1],name='flatten_activations')

我以下面的方式调用GRU单元

outputs,state = rnn.rnn(cell=cell,inputs=x,initial_state=initial_state,sequence_length=s)

其中s是批处理长度数组,其中包含输入批处理的每个元素中的时间戳数.

最后我拿到了

fetched = sess.run(fetches=cell.activations,feed_dict=feed_dict)

执行时我收到以下错误

Traceback(最近一次调用最后一次):
??文件“xxx.py”,第162行,in
????fetched = sess.run(fetches = cell.activations,feed_dict = feed_dict)
??文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第315行,在运行中
????return self._run(None,fetches,feed_dict)
??在_run中输入文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第511行
????feed_dict_string)
??在_do_run中输入文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第564行
????target_list)
??在_do_call中输入文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第588行
????six.reraise(e_type,e_value,e_traceback)
??在_do_call中输入文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第571行
????return fn(* args)
??在_run_fn中输入文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第555行

return tf_session.TF_Run(session,feed_dict,fetch_list,target_list)
tensorflow.python.pywrap_tensorflow.StatusNotOK:无效参数:为RNN / cond_396 / ClusterableGRUCell / flatten_activations返回的张量:0无效.

有人可以通过传递可变长度序列来了解如何在最后一步从GRU单元获取激活吗?谢谢.

解决方法

要从最后一步获取激活,您需要将激活作为状态的一部分,由tf.rnn返回.

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读