加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Python > 正文

python – 在numpy数组中找到平衡点

发布时间:2020-12-20 12:08:13 所属栏目:Python 来源:网络整理
导读:考虑这个数组: a = np.array([1,2,3,4,1]) 我想得到均匀分割数组的元素,即元素之前的数组总和等于数组之后的数组之和.在这种情况下,第4个元素a [3]均匀地划分数组.有更快((numpy)的方式吗?或者我必须迭代所有元素? 期望的功能: f(a) = 3 解决方法 如果所
考虑这个数组:

a = np.array([1,2,3,4,1])

我想得到均匀分割数组的元素,即元素之前的数组总和等于数组之后的数组之和.在这种情况下,第4个元素a [3]均匀地划分数组.有更快((numpy)的方式吗?或者我必须迭代所有元素?

期望的功能:

f(a) =  3

解决方法

如果所有输入值都是非负的,那么最有效的方法之一似乎是建立一个累积和数组,然后二进制搜索它的位置,两边的总和是一半.但是,这样的二进制搜索错误也很容易.在尝试对所有边缘情况进行二进制搜索时,我结束了以下测试:

class SplitpointTest(unittest.TestCase):
    def testFloatRounding(self):
        # Due to rounding error,the cumulative sums for these inputs are
        # [1.1,3.3000000000000003,5.5,6.6]
        # and [0.1,0.7999999999999999,1.5,1.6]
        # Note that under default settings,numpy won't display
        # enough precision to see that.
        self.assertEquals(2,splitpoint([1.1,2.2,1e-20,1.1]))
        self.assertEquals(2,splitpoint([0.1,0.7,0.1]))

    def testIntRounding(self):
        self.assertEquals(1,splitpoint([1,1,1]))
    def testIntPrecision(self):
        self.assertEquals(2,splitpoint([2**60,2**60]))
    def testIntMax(self):
        self.assertEquals(
            2,splitpoint(numpy.array([40,23,63],dtype=numpy.int8))
        )

    def testIntZeros(self):
        self.assertEquals(
            4,splitpoint(numpy.array([0,1],dtype=int))
        )
    def testFloatZeros(self):
        self.assertEquals(
            4,dtype=float))
        )

在决定它不值得之前,我经历了以下版本:

def splitpoint(a):
    c = numpy.cumsum(a)
    return numpy.searchsorted(c,c[-1]/2)
    # Fails on [1,1]

def splitpoint(a):
    c = numpy.cumsum(a)
    return numpy.searchsorted(c,c[-1]/2.0)
    # Fails on [2**60,2**60]

def splitpoint(a):
    c = numpy.cumsum(a)
    if c.dtype.kind == 'f':
        # Floating-point input.
        return numpy.searchsorted(c,c[-1]/2.0)
    elif c.dtype.kind in ('i','u'):
        # Integer input.
        return numpy.searchsorted(c,(c[-1]+1)//2)
    else:
        # Probably an object dtype. No great options.
        return numpy.searchsorted(c,c[-1]/2.0)
    # Fails on numpy.array([63,dtype=int8)

def splitpoint(a):
    c = numpy.cumsum(a)
    if c.dtype.kind == 'f':
        # Floating-point input.
        return numpy.searchsorted(c,c[-1]//2 + c[-1]%2)
    else:
        # Probably an object dtype. No great options.
        return numpy.searchsorted(c,c[-1]/2.0)
    # Still fails the floating-point rounding and zeros tests.

如果我继续努力,我可能会得到这个工作,但这不值得. chw21的第二个解决方案,即基于明确最小化左右和之间的绝对差异的解决方案,更容易推理并且更普遍适用.通过添加a = numpy.asarray(a),它会传递所有上述测试用例以及以下测试,这些测试扩展了算法预期要处理的输入类型:

class SplitpointGeneralizedTest(unittest.TestCase):
    def testNegatives(self):
        self.assertEquals(2,splitpoint([-1,5,4]))
    def testComplex(self):
        self.assertEquals(2,splitpoint([1+1j,-5+2j,43,-4+3j]))
    def testObjectDtype(self):
        from fractions import Fraction
        from decimal import Decimal
        self.assertEquals(2,splitpoint(map(Fraction,[1.5,2.5,3.5,4])))
        self.assertEquals(2,splitpoint(map(Decimal,4])))

除非特别发现它太慢,否则我会选择chw21的第二个解决方案.在我测试它的略微修改的形式中,将是以下内容:

def splitpoint(a):
    a = np.asarray(a)
    c1 = a.cumsum()
    c2 = a[::-1].cumsum()[::-1]
    return np.argmin(np.abs(c1-c2))

我能看到的唯一缺陷是,如果输入有一个无符号的dtype并且没有完全拆分输入的索引,那么这个算法可能不会返回最接近拆分输入的索引,因为np.abs(c1-c2)对于无符号数据类型没有做正确的事情.如果没有拆分索引,则从未指定算法应该做什么,所以这种行为是可以接受的,尽管可能值得在注释中记录np.abs(c1-c2)和unsigned dtypes.如果我们希望索引最接近分割输入,我们可以以一些额外的运行时为代价来获取它:

def splitpoint(a):
    a = np.asarray(a)
    c1 = a.cumsum()
    c2 = a[::-1].cumsum()[::-1]
    if a.dtype.kind == 'u':
        # np.abs(c1-c2) doesn't work on unsigned ints
        absdiffs = np.where(c1>c2,c1-c2,c2-c1)
    else:
        # c1>c2 doesn't work on complex input.
        # We also use this case for other dtypes,since it's
        # probably faster.
        absdiffs = np.abs(c1-c2)
    return np.argmin(absdiffs)

当然,这是对此行为的测试,修改后的表单会传递,未修改的表单会失败:

class SplitpointUnsignedTest(unittest.TestCase):
    def testBestApproximation(self):
        self.assertEquals(1,splitpoint(numpy.array([5,5],dtype=numpy.uint32)))

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读