deep_learning_Function_tf.train.ExponentialMovingAverage()滑
近来看batch normalization的代码时,遇到tf.train.ExponentialMovingAverage()函数,特此记录。 # 类,用于计算滑动平均 tf.train.ExponentialMovingAverage __init__( decay,num_updates=None,zero_debias=False,name=‘ExponentialMovingAverage‘) decay是衰减率。在创建ExponentialMovingAverage对象时,需要指定衰减率(decay),用于控制模型的更新速度。影子变量的初始值与训练变量的初始值相同。当运行变量更新时,每个影子变量都会更新为:
apply()方法添加了训练变量的影子副本,并保持了其影子副本中训练变量的移动平均值操作。在每次训练之后调用此操作,更新移动平均值。average()和average_name()方法可以获取影子变量及其名称。 # 创建variables. var0 = tf.Variable(...) var1 = tf.Variable(...) # ... 使用variables去创建一个训练模型... ... # 创建一个使用the optimizer对的op. # 这是我们通常会使用作为一个training op. opt_op = opt.minimize(my_loss,[var0,var1]) # 创建一个ExponentialMovingAverage object ema = tf.train.ExponentialMovingAverage(decay=0.9999) # 创建the shadow variables,然后把ops加到maintain moving averages of var0 and var1. maintain_averages_op = ema.apply([var0,var1]) # 创建一个op,在每次训练之后用来更新the moving averages. # 用来代替the usual training op. with tf.control_dependencies([opt_op]): training_op = tf.group(maintain_averages_op) # run这个op获取当前时刻 ema_value get_var0_average_op = ema.average(var0) 例子: import tensorflow as tf import numpy as np v1 = tf.Variable(0,dtype=tf.float32) step = tf.Variable(tf.constant(0)) ema = tf.train.ExponentialMovingAverage(0.99,step) maintain_average = ema.apply([v1]) with tf.Session() as sess: init = tf.initialize_all_variables() sess.run(init) print(sess.run([v1,ema.average(v1)])) #初始的值都为0 sess.run(tf.assign(v1,5)) #把v1变为5 sess.run(maintain_average) print(sess.run([v1,ema.average(v1)])) # decay=min(0.99,1/10)=0.1,v1=0.1*0+0.9*5=4.5 sess.run(tf.assign(step,10000)) # steps=10000 sess.run(tf.assign(v1,10)) # v1=10 sess.run(maintain_average) print(sess.run([v1,(1+10000)/(10+10000))=0.99,v1=0.99*4.5+0.01*10=4.555 sess.run(maintain_average) print(sess.run([v1,v1=0.99*4.555+0.01*10=4.60945 > [0.0,0.0] > [5.0,4.5] > [10.0,4.555] > [10.0,4.60945]
(编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |
- Java数学包,用于偏斜法线的逆累积分布以及泊松和
- java – 查找在eclipse中实现抽象类的所有具体类
- 从javax.swing.text尝试针对AbstractDocument.Un
- java – 如何在form,validation和ddl中重用field
- java – 如何在NotSerializableException中识别匿
- 浅谈java实现重载的方法
- Hibernate default_entity_mode属性:指定默认实
- 多线程 – 在Grand Central Dispatch中使用术语“
- java – 将侦听器传递给Android中的自定义片段
- 详解Spring MVC3返回JSON数据中文乱码问题解决