加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Python > 正文

python-在Tensorflow中将自定义渐变定义为类方法

发布时间:2020-12-17 17:38:01 所属栏目:Python 来源:网络整理
导读:我需要将方法定义为自定义渐变,如下所示: class CustGradClass: def __init__(self): pass @tf.custom_gradient def f(self,x): fx = x def grad(dy): return dy * 1 return fx,grad 我收到以下错误: ValueError: Attempt to convert a value ( main .Cust

我需要将方法定义为自定义渐变,如下所示:

class CustGradClass:

    def __init__(self):
        pass

    @tf.custom_gradient
    def f(self,x):
      fx = x
      def grad(dy):
        return dy * 1
      return fx,grad

我收到以下错误:

ValueError: Attempt to convert a value (<main.CustGradClass object at 0x12ed91710>) with an unsupported type () to a Tensor.

原因是自定义渐变接受函数f(* x),其中x是张量序列.传递的第一个参数是对象本身,即自我.

从documentation开始:

f: function f(*x) that returns a tuple (y,grad_fn) where:
x is a sequence of Tensor inputs to the function.
y is a Tensor or sequence of Tensor outputs of applying TensorFlow operations in f to x.
grad_fn is a function with the signature g(*grad_ys)

我该如何运作?我需要继承一些python tensorflow类吗?

我正在使用TF版本1.12.0和渴望模式.

最佳答案
这是一种可能的简单解决方法:

import tensorflow as tf

class CustGradClass:

    def __init__(self):
        self.f = tf.custom_gradient(lambda x: CustGradClass._f(self,x))

    @staticmethod
    def _f(self,x):
        fx = x * 1
        def grad(dy):
            return dy * 1
        return fx,grad

with tf.Graph().as_default(),tf.Session() as sess:
    x = tf.constant(1.0)
    c = CustGradClass()
    y = c.f(x)
    print(tf.gradients(y,x))
    # [<tf.Tensor 'gradients/IdentityN_grad/mul:0' shape=() dtype=float32>]

编辑:

如果您想在不同的类上多次执行此操作,或者只想使用更可重用的解决方案,则可以使用如下所示的装饰器:

import functools
import tensorflow as tf

def tf_custom_gradient_method(f):
    @functools.wraps(f)
    def wrapped(self,*args,**kwargs):
        if not hasattr(self,'_tf_custom_gradient_wrappers'):
            self._tf_custom_gradient_wrappers = {}
        if f not in self._tf_custom_gradient_wrappers:
            self._tf_custom_gradient_wrappers[f] = tf.custom_gradient(lambda *a,**kw: f(self,*a,**kw))
        return self._tf_custom_gradient_wrappers[f](*args,**kwargs)
    return wrapped

然后,您可以这样做:

class CustGradClass:

    def __init__(self):
        pass

    @tf_custom_gradient_method
    def f(self,grad

    @tf_custom_gradient_method
    def f2(self,x):
        fx = x * 2
        def grad(dy):
            return dy * 2
        return fx,grad

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读