加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 编程开发 > Python > 正文

python – 在大量数组中搜索最近的数组

发布时间:2020-12-20 11:03:47 所属栏目:Python 来源:网络整理
导读:我需要找到最接近的句子. 我有一个句子数组和一个用户句子,我需要找到最接近用户的数组句子元素. 我使用word2vec以向量的形式呈现每个句子: def get_avg_vector(word_list,model_w2v,size=500): sum_vec = np.zeros(shape = (1,size)) count = 0 for w in w
我需要找到最接近的句子.
我有一个句子数组和一个用户句子,我需要找到最接近用户的数组句子元素.

我使用word2vec以向量的形式呈现每个句子:

def get_avg_vector(word_list,model_w2v,size=500):
    sum_vec = np.zeros(shape = (1,size))
    count = 0

    for w in word_list:
        if w in model_w2v and w != '':
            sum_vec += model_w2v[w]
            count +=1
    if count == 0:
        return sum_vec
    else:
        return sum_vec / count + 1

结果,数组元素如下所示:

array([[ 0.93162371,0.95618944,0.98519795,0.98580566,0.96563747,0.97070891,0.99079191,1.01572807,1.00631016,1.07349398,1.02079309,1.0064849,0.99179418,1.02865136,1.02610303,1.02909719,0.99350413,0.97481178,0.97980362,0.98068508,1.05657591,0.97224562,0.99778703,0.97888296,1.01650529,1.0421448,0.98731804,0.98349052,0.93752996,0.98205837,1.05691232,0.99914532,1.02040555,0.99427229,1.01193818,0.94922226,0.9818139,1.03955,1.01252615,1.01402485,...
         0.98990598,0.99576604,1.0903802,1.02493086,0.97395976,0.95563786,1.00538653,1.0036294,0.97220088,1.04822631,1.02806122,0.95402776,1.0048053,0.97677222,0.97830801]])

我将用户的句子也表示为向量,并且我计算最接近它的元素是这样的:

%%cython
from scipy.spatial.distance import euclidean

def compute_dist(v,list_sentences):
    dist_dict = {}

    for key,val in list_sentences.items():
        dist_dict[key] = euclidean(v,val)

    return sorted(dist_dict.items(),key=lambda x: x[1])[0][0]

上述方法中的list_sentences是一个字典,其中键是句子的文本表示,值是矢量.

这需要很长时间,因为我有超过6000万句话.
我怎样才能加快,优化这个过程?

我会很感激任何建议.

解决方法

6000万个句子向量的初始计算基本上是你需要支付一次的固定成本.对于单个用户提供的查询语句,我假设您主要关心每次后续查找的时间.

使用numpy本机数组操作可以加快距离计算,而不是在Python循环中进行自己的单独计算. (它可以使用优化的代码批量处理.)

但首先你要用一个真正的numpy数组替换list_sentences,只能通过array-index访问. (如果你有其他键/文本需要与每个插槽关联,你可以在其他地方使用某些字典或列表.)

让我们假设你已经以任何自然的方式完成了这项工作,现在有了array_sentences,一个6000万×500维的numpy数组,每行有一个句子平均向量.

然后,获得一个充满距离的数组的单线方式是作为6000万候选者和1个查询中的每一个之间的差异的向量长度(“标准”)(其给出每个6,000万个条目答案)差异):

dists = np.linalg.norm(array_sentences - v)

另一种方法是使用numpy实用程序函数cdist()来计算每对输入集合之间的距离.在这里,您的第一个集合只是一个查询向量v(但如果您一次批量处理,一次提供多个查询可以提供额外的轻微加速):

dists = np.linalg.cdists(array[v],array_sentences)

(请注意,这样的矢量比较通常使用余弦距离/余弦相似度而不是欧几里德距离.如果切换到那个,你可能正在做其他的规范/点积而不是上面的第一个选项,或者使用metric =’ cdine()的余弦’选项.)

一旦你在numpy数组中拥有所有距离,使用numpy-native排序选项可能比使用Python sorted()更快.例如,numpy的间接排序argsort(),它只返回已排序的索引(从而避免移动所有的矢量坐标),因为您只想知道哪些项是最佳匹配.例如:

sorted_indexes = argsort(dists)
best_index = sorted_indexes[0]

如果您需要将该int索引转换回您的其他键/文本,您将使用自己的dict / list来记住插槽到键的关系.

通过与所有候选人进行比较,所有这些仍然给出了完全正确的结果,即使(即使在最佳状态下做得好)仍然是耗时的.

有一些方法可以获得更快的结果,基于对整个候选者的预建索引 – 但这样的索引在高维空间(如500维空间)中变得非常棘手.他们经常以完全准确的结果进行权衡,以获得更快(也就是说,他们为“最接近的1”或“最接近的N”返回的内容会有一些错误,但通常不会有太多错误.)有关此类库的示例,请参阅Spotify’s ANNOY或Facebook’s FAISS.

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读