加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Python > 正文

python – 基于2d掩码数组的numpy 3d到2d变换

发布时间:2020-12-20 13:33:35 所属栏目:Python 来源:网络整理
导读:如果我有这样的ndarray: a = np.arange(27).reshape(3,3,3) aarray([[[ 0,1,2],[ 3,4,5],[ 6,7,8]],[[ 9,10,11],[12,13,14],[15,16,17]],[[18,19,20],[21,22,23],[24,25,26]]]) 我知道我可以使用np.max(axis = …)沿某个轴获得最大值: a.max(axis=2)array(
如果我有这样的ndarray:

>>> a = np.arange(27).reshape(3,3,3)
>>> a
array([[[ 0,1,2],[ 3,4,5],[ 6,7,8]],[[ 9,10,11],[12,13,14],[15,16,17]],[[18,19,20],[21,22,23],[24,25,26]]])

我知道我可以使用np.max(axis = …)沿某个轴获得最大值:

>>> a.max(axis=2)
array([[ 2,5,8],[11,14,17],[20,23,26]])

或者,我可以获得沿该轴的索引,这些索引对应于以下的最大值:

>>> indices = a.argmax(axis=2)
>>> indices
array([[2,2,[2,2]])

我的问题 – 给定数组索引和数组a,是否有一种优雅的方法来重现a.max(axis = 2)返回的数组?

这可能会奏效:

import itertools as it
import numpy as np
def apply_mask(field,indices):
    data = np.empty(indices.shape)

    #It seems highly likely that there is a more numpy-approved way to do this.
    idx = [range(i) for i in indices.shape]
    for idx_tup,zidx in zip(it.product(*idx),indices.flat):
        data[idx_tup] = field[idx_tup+(zidx,)]
    return data

但是,它似乎非常hacky /低效.它也不允许我使用除“最后”轴以外的任何轴.是否有一个numpy函数(或一些使用神奇的numpy索引)来使这项工作?天真的a [:,:,a.argmax(axis = 2)]不起作用.

更新:

似乎以下也有效(并且更好一些):

import numpy as np
def apply_mask(field,indices):
    data = np.empty(indices.shape)

    for idx_tup,zidx in np.ndenumerate(indices):
        data[idx_tup] = field[idx_tup+(zidx,)]

    return data

我想这样做是因为我想基于1个数组中的数据提取索引(通常使用argmax(axis = …))并使用这些索引从一堆其他数据中提取数据(等效形状)阵列.我愿意采用其他方法来实现这一目标(例如使用布尔掩码数组).但是,我喜欢使用这些“索引”数组的“安全性”.有了这个,我保证有适当数量的元素来创建一个新的数组,看起来像是通过3d字段的2d“切片”.

解决方法

这里有一些神奇的numpy索引可以做你想要的,但不幸的是它很难读.

def apply_mask(a,indices,axis):
    magic_index = [np.arange(i) for i in indices.shape]
    magic_index = np.ix_(*magic_index)
    magic_index = magic_index[:axis] + (indices,) + magic_index[axis:]
    return a[magic_index]

或同样不可读:

def apply_mask(a,axis):
    magic_index = np.ogrid[tuple(slice(i) for i in indices.shape)]
    magic_index.insert(axis,indices)
    return a[magic_index]

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读