加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Python > 正文

python – 如何使用TensorFlow Graphkeys获取所有权重?

发布时间:2020-12-20 12:02:02 所属栏目:Python 来源:网络整理
导读:Tensorflow定义了预设的集合,如下所示: https://www.tensorflow.org/versions/r0.12/api_docs/python/framework/graph_collections 我目前正在使用tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)来获取所有变量[*的名字;如果他们不是那么他们就不会被展
Tensorflow定义了预设的集合,如下所示: https://www.tensorflow.org/versions/r0.12/api_docs/python/framework/graph_collections

我目前正在使用tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)来获取所有变量[*的名字;如果他们不是那么他们就不会被展示,即使他们存在].

类似地,我期望tf.get_collection(tf.GraphKeys.WEIGHTS)输出权重列表,但它是一个空数组.这也适用于GraphKeys.BIASES和.ACTIVATIONS.

这里发生了什么?

在我看来,这里似乎有两种可能性.首先,它们实际上从未自动定义,只是推荐的集合名称.其次,我的网络非常破碎,但似乎并非如此.

有人有这方面的经验吗?

解决方法

默认情况下,所有变量都绑定到tf.GraphKeys.GLOBAL_VARIABLES集合.方便的方法是将每个权重设置为集合tf.GraphKeys.WEIGHTS,如下所示:

In [2]: w = tf.Variable([1,2,3],collections=[tf.GraphKeys.WEIGHTS],dtype=tf.float32)
In [3]: w2 = tf.Variable([11,22,32],dtype=tf.float32)

然后你可以通过以下方式获取它们:

tf.get_collection_ref(tf.GraphKeys.WEIGHTS)

这是权重:

[<tf.Variable 'Variable:0' shape=(3,) dtype=float32_ref>,<tf.Variable 'Variable_1:0' shape=(3,) dtype=float32_ref>]

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读