Weighted random sample in Python
Published: 2020-12-20 10:34:25 | Category: Python | Source: compiled from the web
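I am looking for a reasonable definition of a function weighted_sample that does not return just one random index for a given list of weights, which would be something like

    import random

    def weighted_choice(weights, random=random):
        """
        Given a list of weights [w_0, w_1, ..., w_n-1], return an index i in
        range(n) with probability proportional to w_i.
        """
        # Pick a point in [0, sum(weights)) and walk the weights until it is passed.
        rnd = random.random() * sum(weights)
        for i, w in enumerate(weights):
            if w < 0:
                raise ValueError("Negative weight encountered.")
            rnd -= w
            if rnd < 0:
                return i
        raise ValueError("Sum of weights is not positive")

(giving a categorical distribution with constant weights), but instead returns a random sample of k such indices, drawn without replacement, in the same way that random.sample relates to random.choice.

Just as weighted_choice can be written as

    lambda weights: random.choice([val for val, cnt in enumerate(weights) for i in range(cnt)])

weighted_sample can be written as

    lambda weights, k: random.sample([val for val, cnt in enumerate(weights) for i in range(cnt)], k)

but I would like a solution that does not require me to unravel the weights into a (possibly huge) list.

Edit: If there is a good algorithm that returns a histogram / list of frequencies (in the same format as the weights argument) instead of a sequence of indices, that would also be very useful.

Solution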
From your code:

    weight_sample_indexes = lambda weights, k: random.sample([val for val, cnt in enumerate(weights) for i in range(cnt)], k)

I assume that the weights are positive integers, and that by "without replacement" you mean without replacement from the unraveled sequence.

Here is a solution based on random.sample and an O(log n) __getitem__:

    import bisect
    import random
    from collections import Counter
    # Sequence lives in collections.abc; the alias in collections was removed in Python 3.10.
    from collections.abc import Sequence

    def weighted_sample(population, weights, k):
        return random.sample(WeightedPopulation(population, weights), k)

    class WeightedPopulation(Sequence):
        """Virtual sequence in which population[j] appears weights[j] times."""

        def __init__(self, population, weights):
            assert len(population) == len(weights) > 0
            self.population = population
            self.cumweights = []
            cumsum = 0  # compute cumulative weight
            for w in weights:
                cumsum += w
                self.cumweights.append(cumsum)

        def __len__(self):
            # Length of the unraveled sequence is the total weight.
            return self.cumweights[-1]

        def __getitem__(self, i):
            if not 0 <= i < len(self):
                raise IndexError(i)
            # Binary search for the item whose cumulative range contains i.
            return self.population[bisect.bisect(self.cumweights, i)]

Example

    total = Counter()
    for _ in range(1000):
        sample = weighted_sample("abc", [1, 10, 2], 5)
        total.update(sample)
    print(sample)
    print("Frequencies %s" % (dict(Counter(sample)),))
    # Check that values are sane
    print("Total " + ','.join("%s: %.0f" % (val, count * 1.0 / min(total.values()))
                              for val, count in total.most_common()))

Output (one run; the sample itself varies from run to run)

    ['b', 'b', 'b', 'c', 'c']
    Frequencies {'c': 2, 'b': 3}
    Total b: 10,c: 2,a: 1

(Editor: 李大同)
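The edit to the question also asks for a histogram in the same format as the weights argument. A minimal sketch of one way to get it, assuming Python 3.9 or later, where random.sample accepts a counts keyword that plays the role of the integer weights. The function name weighted_sample_histogram is hypothetical and not part of the original answer; like the answer, it assumes non-negative integer weights.

```python
import random
from collections import Counter


def weighted_sample_histogram(weights, k):
    """Hypothetical helper: draw k indices without replacement from the
    unraveled population and return how often each index was drawn,
    as a list aligned with ``weights``.

    Assumes non-negative integer weights and Python 3.9+.
    """
    # ``counts`` tells random.sample how many times each index is repeated,
    # so the caller never has to build the unraveled list.
    drawn = random.sample(range(len(weights)), k, counts=weights)
    tally = Counter(drawn)
    # Counter returns 0 for indices that were never drawn.
    return [tally[i] for i in range(len(weights))]


if __name__ == "__main__":
    print(weighted_sample_histogram([1, 10, 2], 5))
```

Each returned count is bounded by the corresponding weight and the counts sum to k. In particular, weighted_sample_histogram([1, 10, 2], 13) returns [1, 10, 2], because sampling the whole unraveled population without replacement exhausts every weight.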