python – 在大量数组中搜索最近的数组
我需要找到最接近的句子.
我有一个句子数组和一个用户句子,我需要找到最接近用户的数组句子元素. 我使用word2vec以向量的形式呈现每个句子: def get_avg_vector(word_list,model_w2v,size=500): sum_vec = np.zeros(shape = (1,size)) count = 0 for w in word_list: if w in model_w2v and w != '': sum_vec += model_w2v[w] count +=1 if count == 0: return sum_vec else: return sum_vec / count + 1 结果,数组元素如下所示: array([[ 0.93162371,0.95618944,0.98519795,0.98580566,0.96563747,0.97070891,0.99079191,1.01572807,1.00631016,1.07349398,1.02079309,1.0064849,0.99179418,1.02865136,1.02610303,1.02909719,0.99350413,0.97481178,0.97980362,0.98068508,1.05657591,0.97224562,0.99778703,0.97888296,1.01650529,1.0421448,0.98731804,0.98349052,0.93752996,0.98205837,1.05691232,0.99914532,1.02040555,0.99427229,1.01193818,0.94922226,0.9818139,1.03955,1.01252615,1.01402485,... 0.98990598,0.99576604,1.0903802,1.02493086,0.97395976,0.95563786,1.00538653,1.0036294,0.97220088,1.04822631,1.02806122,0.95402776,1.0048053,0.97677222,0.97830801]]) 我将用户的句子也表示为向量,并且我计算最接近它的元素是这样的: %%cython from scipy.spatial.distance import euclidean def compute_dist(v,list_sentences): dist_dict = {} for key,val in list_sentences.items(): dist_dict[key] = euclidean(v,val) return sorted(dist_dict.items(),key=lambda x: x[1])[0][0] 上述方法中的list_sentences是一个字典,其中键是句子的文本表示,值是矢量. 这需要很长时间,因为我有超过6000万句话. 我会很感激任何建议. 解决方法
6000万个句子向量的初始计算基本上是你需要支付一次的固定成本.对于单个用户提供的查询语句,我假设您主要关心每次后续查找的时间.
使用numpy本机数组操作可以加快距离计算,而不是在Python循环中进行自己的单独计算. (它可以使用优化的代码批量处理.) 但首先你要用一个真正的numpy数组替换list_sentences,只能通过array-index访问. (如果你有其他键/文本需要与每个插槽关联,你可以在其他地方使用某些字典或列表.) 让我们假设你已经以任何自然的方式完成了这项工作,现在有了array_sentences,一个6000万×500维的numpy数组,每行有一个句子平均向量. 然后,获得一个充满距离的数组的单线方式是作为6000万候选者和1个查询中的每一个之间的差异的向量长度(“标准”)(其给出每个6,000万个条目答案)差异): dists = np.linalg.norm(array_sentences - v) 另一种方法是使用numpy实用程序函数cdist()来计算每对输入集合之间的距离.在这里,您的第一个集合只是一个查询向量v(但如果您一次批量处理,一次提供多个查询可以提供额外的轻微加速): dists = np.linalg.cdists(array[v],array_sentences) (请注意,这样的矢量比较通常使用余弦距离/余弦相似度而不是欧几里德距离.如果切换到那个,你可能正在做其他的规范/点积而不是上面的第一个选项,或者使用metric =’ cdine()的余弦’选项.) 一旦你在numpy数组中拥有所有距离,使用numpy-native排序选项可能比使用Python sorted()更快.例如,numpy的间接排序argsort(),它只返回已排序的索引(从而避免移动所有的矢量坐标),因为您只想知道哪些项是最佳匹配.例如: sorted_indexes = argsort(dists) best_index = sorted_indexes[0] 如果您需要将该int索引转换回您的其他键/文本,您将使用自己的dict / list来记住插槽到键的关系. 通过与所有候选人进行比较,所有这些仍然给出了完全正确的结果,即使(即使在最佳状态下做得好)仍然是耗时的. 有一些方法可以获得更快的结果,基于对整个候选者的预建索引 – 但这样的索引在高维空间(如500维空间)中变得非常棘手.他们经常以完全准确的结果进行权衡,以获得更快(也就是说,他们为“最接近的1”或“最接近的N”返回的内容会有一些错误,但通常不会有太多错误.)有关此类库的示例,请参阅Spotify’s ANNOY或Facebook’s FAISS. (编辑:李大同) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |