python – Finding the equilibrium point in a numpy array
Consider this array:
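```python
a = np.array([1, 2, 3, 4, 3, 2, 1])
```

I want to get the element that divides the array evenly, i.e. the element for which the sum of the array before it equals the sum of the array after it. In this case, the 4th element, a[3], divides the array evenly (1 + 2 + 3 = 3 + 2 + 1). Is there a faster (NumPy) way of doing this, or do I have to iterate over all the elements?

Desired function: f(a) = 3

Solution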
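As a baseline, here is the element-by-element iteration the question mentions. It is a single pass that keeps a running left sum and derives each right sum from the total. This is a minimal sketch: the name `balance_index_loop` is hypothetical, it returns the first balancing index or `None`, and its exact `==` comparison is only reliable for plain Python integers or other exact number types, not floats.

```python
def balance_index_loop(values):
    """Return the first index i with sum(values[:i]) == sum(values[i+1:]), or None.

    Hypothetical helper for illustration: one O(n) pass in pure Python.
    """
    values = list(values)
    total = sum(values)
    left = 0  # running sum of the elements before index i
    for i, x in enumerate(values):
        right = total - left - x  # sum of the elements after index i
        if left == right:
            return i
        left += x
    return None  # no index splits the input exactly


print(balance_index_loop([1, 2, 3, 4, 3, 2, 1]))  # 3
print(balance_index_loop([1, 1]))                 # None
```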
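If all input values are non-negative, one of the most efficient approaches seems to be to build a cumulative-sum array and then binary-search it for the position where the sums on both sides are half the total. However, a binary search like this is also easy to get wrong. While trying to make the binary search handle every edge case, I ended up with the following tests: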
```python
import unittest

import numpy


class SplitpointTest(unittest.TestCase):
    def testFloatRounding(self):
        # Due to rounding error, the cumulative sums for these inputs are
        # [1.1, 3.3000000000000003, 5.5, 6.6]
        # and [0.1, 0.7999999999999999, 1.5, 1.6]
        # Note that under default settings, numpy won't display
        # enough precision to see that.
        self.assertEqual(2, splitpoint([1.1, 2.2, 1e-20, 1.1]))
        self.assertEqual(2, splitpoint([0.1, 0.7, 0.1]))

    def testIntRounding(self):
        self.assertEqual(1, splitpoint([1, 1, 1]))

    def testIntPrecision(self):
        self.assertEqual(2, splitpoint([2**60, 2**60]))

    def testIntMax(self):
        self.assertEqual(
            2, splitpoint(numpy.array([40, 23, 63], dtype=numpy.int8))
        )

    def testIntZeros(self):
        self.assertEqual(
            4, splitpoint(numpy.array([0, 1], dtype=int))
        )

    def testFloatZeros(self):
        self.assertEqual(
            4, splitpoint(numpy.array([0, 1], dtype=float))
        )
```

I went through the following versions before deciding it wasn't worth it:

```python
def splitpoint(a):
    c = numpy.cumsum(a)
    # Under Python 2, / on two integers floors, so this is c[-1]//2 for int input.
    return numpy.searchsorted(c, c[-1]/2)
    # Fails on [1,1]

def splitpoint(a):
    c = numpy.cumsum(a)
    return numpy.searchsorted(c, c[-1]/2.0)
    # Fails on [2**60,2**60]

def splitpoint(a):
    c = numpy.cumsum(a)
    if c.dtype.kind == 'f':
        # Floating-point input.
        return numpy.searchsorted(c, c[-1]/2.0)
    elif c.dtype.kind in ('i', 'u'):
        # Integer input.
        return numpy.searchsorted(c, (c[-1]+1)//2)
    else:
        # Probably an object dtype. No great options.
        return numpy.searchsorted(c, c[-1]/2.0)
    # Fails on the int8 input (testIntMax).

def splitpoint(a):
    c = numpy.cumsum(a)
    if c.dtype.kind == 'f':
        # Floating-point input.
        return numpy.searchsorted(c, c[-1]/2.0)
    elif c.dtype.kind in ('i', 'u'):
        # Integer input: ceil(c[-1]/2) without the overflow risk of c[-1]+1.
        return numpy.searchsorted(c, c[-1]//2 + c[-1]%2)
    else:
        # Probably an object dtype. No great options.
        return numpy.searchsorted(c, c[-1]/2.0)
    # Still fails the floating-point rounding and zeros tests.
```

If I kept at it, I could probably get this working, but it isn't worth it.
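The attempts above all binary-search the cumulative sum for a halved total, which is where the rounding trouble comes from. For the narrower case of non-negative integer input, one way to keep the search exact is to search a different sequence. The sequence d[i] = sum(a[:i]) - sum(a[i+1:]) never decreases when every element is >= 0, because d[i+1] - d[i] = a[i] + a[i+1]. A balancing index is then simply an index where d[i] == 0. This is a sketch under those assumptions, not the answer's code: `splitpoint_nonneg_int` is a hypothetical name, and it assumes the total fits in int64.

```python
import numpy as np


def splitpoint_nonneg_int(a):
    """Binary-search sketch for non-negative integer input (hypothetical helper).

    Returns the first index i with sum(a[:i]) == sum(a[i+1:]), or None.
    """
    a = np.asarray(a)
    if a.size and a.dtype.kind not in "iu":
        raise TypeError("this sketch assumes integer input")
    a = a.astype(np.int64)  # assumes the total fits in int64
    if (a < 0).any():
        raise ValueError("this sketch assumes non-negative input")
    c = np.cumsum(a)
    total = c[-1] if len(c) else 0
    left = c - a       # sum of the elements before index i
    right = total - c  # sum of the elements after index i
    d = left - right   # non-decreasing because every a[i] >= 0
    i = int(np.searchsorted(d, 0))  # first index with d[i] >= 0
    if i < len(d) and d[i] == 0:
        return i
    return None  # no exact split


print(splitpoint_nonneg_int([1, 2, 3, 4, 3, 2, 1]))  # 3
print(splitpoint_nonneg_int([0, 0, 0, 0, 1]))        # 4
```

Every step stays in integer arithmetic, so there is no c[-1]/2 rounding, and runs of zeros are handled by the explicit d[i] == 0 check. The cumulative sum is still a linear pass, so the binary search only replaces the final scan. Float input would still run into the rounding problems the tests above describe.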
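chw21's second solution, the one based on explicitly minimizing the absolute difference between the left and right sums, is much easier to reason about and more generally applicable. With the addition of a = numpy.asarray(a), it passes all of the test cases above, plus the following tests, which expand the range of inputs the algorithm is expected to handle:

```python
class SplitpointGeneralizedTest(unittest.TestCase):
    def testNegatives(self):
        self.assertEqual(2, splitpoint([-1, 5, 4]))

    def testComplex(self):
        self.assertEqual(2, splitpoint([1+1j, -5+2j, 43, -4+3j]))

    def testObjectDtype(self):
        from fractions import Fraction
        from decimal import Decimal
        # list() is needed on Python 3, where map() returns an iterator.
        self.assertEqual(2, splitpoint(list(map(Fraction, [1.5, 2.5, 3.5, 4]))))
        self.assertEqual(2, splitpoint(list(map(Decimal, [1.5, 2.5, 3.5, 4]))))
```

Unless it turned out to be too slow in practice, I would go with chw21's second solution. In the slightly modified form I tested, it is the following:

```python
import numpy as np


def splitpoint(a):
    a = np.asarray(a)
    c1 = a.cumsum()              # c1[i] == sum(a[:i+1])
    c2 = a[::-1].cumsum()[::-1]  # c2[i] == sum(a[i:])
    return np.argmin(np.abs(c1 - c2))
```

The only flaw I can see arises when the input has an unsigned dtype and no index splits it exactly. In that case, this algorithm may not return the index that comes closest to splitting the input, because np.abs(c1-c2) doesn't do the right thing for unsigned dtypes: the subtraction wraps around instead of going negative. It was never specified what the algorithm should do when there is no split index, so this behavior is acceptable. Still, it may be worth a comment noting how np.abs(c1-c2) behaves with unsigned dtypes. If we want the index that comes closest to splitting the input, we can get it at the cost of some extra runtime:

```python
def splitpoint(a):
    a = np.asarray(a)
    c1 = a.cumsum()
    c2 = a[::-1].cumsum()[::-1]
    if a.dtype.kind == 'u':
        # np.abs(c1-c2) doesn't work on unsigned ints
        absdiffs = np.where(c1 > c2, c1 - c2, c2 - c1)
    else:
        # c1>c2 doesn't work on complex input.
        # We also use this case for other dtypes, since it's
        # probably faster.
        absdiffs = np.abs(c1 - c2)
    return np.argmin(absdiffs)
```

And here is a test for this behavior, which the modified form passes and the unmodified form fails:

```python
class SplitpointUnsignedTest(unittest.TestCase):
    def testBestApproximation(self):
        self.assertEqual(1, splitpoint(numpy.array([5, 5], dtype=numpy.uint32)))
```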
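np.argmin always returns some index, even when no exact split exists. For callers who want only exact splits, a variant can reuse the same c1/c2 arrays and keep every index where the difference is exactly zero. This works because c1[i] - c2[i] == sum(a[:i]) - sum(a[i+1:]): a[i] appears in both sums and cancels. This is a sketch with the hypothetical name `exact_splitpoints`. It assumes integer or other exact input, and it handles the unsigned case with a signed copy, assuming the sums fit in int64.

```python
import numpy as np


def exact_splitpoints(a):
    """Return every index i with sum(a[:i]) == sum(a[i+1:]) (hypothetical helper)."""
    a = np.asarray(a)
    if a.dtype.kind == 'u':
        # Unsigned subtraction wraps around, so work on a signed copy
        # (assumes the sums fit in int64).
        a = a.astype(np.int64)
    c1 = a.cumsum()              # sum(a[:i+1])
    c2 = a[::-1].cumsum()[::-1]  # sum(a[i:])
    # Exact comparison: reliable for integers, not for floats.
    return np.flatnonzero(c1 - c2 == 0)


print(exact_splitpoints([1, 2, 3, 4, 3, 2, 1]))  # [3]
print(exact_splitpoints([1, 1]))                 # []
```

An empty result signals that no exact split exists, which is the case the answer above leaves unspecified. For float input, an approximate comparison such as np.isclose would be the usual substitute for == 0.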