python – 如何覆盖NumPy的ndarray和我的类型之间的比较?
发布时间:2020-12-20 13:32:49 所属栏目:Python 来源:网络整理
导读:在NumPy中,可以使用__array_priority__属性来控制作用于ndarray和用户定义类型的二元运算符.例如: class Foo(object): def __radd__(self,lhs): return 0 __array_priority__ = 100a = np.random.random((100,100))b = Foo()a + b # calls b.__radd__(a) -
在NumPy中,可以使用__array_priority__属性来控制作用于ndarray和用户定义类型的二元运算符.例如:
class Foo(object): def __radd__(self,lhs): return 0 __array_priority__ = 100 a = np.random.random((100,100)) b = Foo() a + b # calls b.__radd__(a) -> 0 然而,同样的事情似乎不适用于比较运算符.例如,如果我将以下行添加到Foo,那么它永远不会从表达式a< b: def __rlt__(self,lhs): return 0 我意识到__rlt__并不是真正的Python特殊名称,但我认为它可能有用.我尝试了所有的__lt __,__ le __,__ eq __,__ ne__,__ ge __,__ gt__,有或没有前面的r,加上__cmp__,但我永远无法让NumPy调用它们中的任何一个. 这些比较可以被覆盖吗? UPDATE 为了避免混淆,这里有一个更长的描述NumPy的行为.对于初学者来说,这是NumPy指南中所说的内容: If the ufunc has 2 inputs and 1 output and the second input is an Object array then a special-case check is performed so that NotImplemented is returned if the second input is not an ndarray,has the array priority attribute,and has an r<op> special method. 我认为这是制定工作的规则.这是一个例子: import numpy as np a = np.random.random((2,2)) class Bar0(object): def __add__(self,rhs): return 0 def __radd__(self,rhs): return 1 b = Bar0() print a + b # Calls __radd__ four times,returns an array # [[1 1] # [1 1]] class Bar1(object): def __add__(self,rhs): return 1 __array_priority__ = 100 b = Bar1() print a + b # Calls __radd__ once,returns 1 # 1 如您所见,在没有__array_priority__的情况下,NumPy将用户定义的对象解释为标量类型,并在数组中的每个位置应用该操作.那不是我想要的.我的类型是数组(但不应该从ndarray派生). 这是一个较长的示例,显示了在定义所有比较方法时如何失败: class Foo(object): def __cmp__(self,rhs): return 0 def __lt__(self,rhs): return 1 def __le__(self,rhs): return 2 def __eq__(self,rhs): return 3 def __ne__(self,rhs): return 4 def __gt__(self,rhs): return 5 def __ge__(self,rhs): return 6 __array_priority__ = 100 b = Foo() print a < b # Calls __cmp__ four times,returns an array # [[False False] # [False False]] 解决方法
看起来我自己可以回答这个问题. np.set_numeric_ops可以使用如下:
class Foo(object): def __lt__(self,rhs): return 0 def __le__(self,rhs): return 1 def __eq__(self,rhs): return 2 def __ne__(self,rhs): return 3 def __gt__(self,rhs): return 4 def __ge__(self,rhs): return 5 __array_priority__ = 100 def override(name): def ufunc(x,y): if isinstance(y,Foo): return NotImplemented return np.getattr(name)(x,y) return ufunc np.set_numeric_ops( ** { ufunc : override(ufunc) for ufunc in ( "less","less_equal","equal","not_equal","greater_equal","greater" ) } ) a = np.random.random((2,2)) b = Foo() print a < b # 4 (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |