Weighted random sample in python

Published 2019-01-06 18:50

Question:

I'm looking for a reasonable definition of a function weighted_sample that does not return just one random index for a list of given weights (which would be something like

import random

def weighted_choice(weights, random=random):
    """ Given a list of weights [w_0, w_1, ..., w_n-1],
        return an index i in range(n) with probability proportional to w_i. """
    rnd = random.random() * sum(weights)
    for i, w in enumerate(weights):
        if w<0:
            raise ValueError("Negative weight encountered.")
        rnd -= w
        if rnd < 0:
            return i
    raise ValueError("Sum of weights is not positive")

to give a categorical distribution with constant weights) but a random sample of k of those, without replacement, just as random.sample behaves compared to random.choice.
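For reference, Python 3.6 and later already cover the single-draw case in the standard library: random.choices draws with replacement, proportionally to the weights, without unravelling them. A minimal sketch of weighted_choice on top of it. The rng parameter is my own addition, not part of the question's code:

```python
import random

def weighted_choice(weights, rng=random):
    """Return an index i with probability proportional to weights[i] (one draw)."""
    # random.choices (Python 3.6+) draws with replacement; k=1 yields a one-item list.
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]

print(weighted_choice([1, 10, 2]))
```

This is still sampling with replacement, so it does not answer the k-without-replacement question by itself.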

Just as weighted_choice can be written as

lambda weights: random.choice([val for val, cnt in enumerate(weights)
    for i in range(cnt)])

weighted_sample could be written as

lambda weights, k: random.sample([val for val, cnt in enumerate(weights)
    for i in range(cnt)], k)

but I would like a solution that does not require me to unravel the weights into a (possibly huge) list.

Edit: If there are any nice algorithms that give me back a histogram/list of frequencies (in the same format as the argument weights) instead of a sequence of indices, that would also be very useful.

Answer 1:

From your code:

weight_sample_indexes = lambda weights, k: random.sample([val
        for val, cnt in enumerate(weights) for i in range(cnt)], k)

I assume that the weights are positive integers and that by "without replacement" you mean without replacement from the unraveled sequence.

Here's a solution based on random.sample and O(log n) __getitem__:

import bisect
import random
from collections import Counter
from collections.abc import Sequence  # collections.Sequence was removed in Python 3.10

def weighted_sample(population, weights, k):
    return random.sample(WeightedPopulation(population, weights), k)

class WeightedPopulation(Sequence):
    """Read-only view of the unraveled sequence: population[j] repeated weights[j] times."""
    def __init__(self, population, weights):
        assert len(population) == len(weights) > 0
        self.population = population
        self.cumweights = []
        cumsum = 0  # compute cumulative weight
        for w in weights:
            cumsum += w
            self.cumweights.append(cumsum)
    def __len__(self):
        return self.cumweights[-1]
    def __getitem__(self, i):
        if not 0 <= i < len(self):
            raise IndexError(i)
        # binary search for the first j with cumweights[j] > i: O(log n)
        return self.population[bisect.bisect(self.cumweights, i)]

Example

total = Counter()
for _ in range(1000):
    sample = weighted_sample("abc", [1,10,2], 5)
    total.update(sample)
print(sample)
print("Frequencies %s" % (dict(Counter(sample)),))

# Check that values are sane
print("Total " + ', '.join("%s: %.0f" % (val, count * 1.0 / min(total.values()))
                           for val, count in total.most_common()))

Output

['b', 'b', 'b', 'c', 'c']
Frequencies {'c': 2, 'b': 3}
Total b: 10, c: 2, a: 1
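On Python 3.9 and later the standard library can do this directly: random.sample accepts a keyword-only counts argument of integer multiplicities and, much like the WeightedPopulation class above, samples without replacement from the implied unraveled sequence without building it. A minimal sketch with the same weighted_sample signature as above:

```python
import random
from collections import Counter

def weighted_sample(population, weights, k):
    # Python 3.9+: counts must be integers, one per population element;
    # draws are without replacement from the (never materialized) unraveled list.
    return random.sample(population, k, counts=weights)

sample = weighted_sample("abc", [1, 10, 2], 5)
print(sample)
print(Counter(sample))  # histogram of the draw, as asked in the question's edit
```

Because counts must be integers, this carries the same integer-weight assumption as the class above.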


Answer 2:

What you want to create is a non-uniform random distribution. One bad way of doing this is to create a giant array with output symbols in proportion to the weights: if a is 5 times more likely than b, you create an array with 5 times more a's than b's. This works fine for simple distributions where the weights are small integer multiples of each other, but what if you wanted 99.99% a and 0.01% b? You'd have to create 10000 slots.

There is a better way. Any non-uniform distribution over N symbols can be decomposed into N-1 binary (two-symbol) distributions, each of which is chosen with equal probability.

So if you had such a decomposition, you would first choose one of the binary distributions at random by generating a uniform random integer from 1 to N-1:

u32 dist = randInRange( 1, N-1 ); // generate a random integer from 1 to N-1

Then, say the chosen distribution is a binary distribution over two symbols a and b, with probability alpha for a and 1-alpha for b:

float f = randomFloat();
return ( f > alpha ) ? b : a;

Decomposing an arbitrary non-uniform distribution is a little more involved. Essentially, you create N-1 "buckets", each holding probability 1/(N-1). Choose the symbol with the lowest probability and the one with the highest probability, and split the first bucket between them: the smallest symbol puts all of its weight into it and the largest fills the rest. Then delete the smallest symbol, subtract from the largest the weight it used to fill this bucket, and repeat with the next bucket until all N-1 buckets are filled.

I can post c++ code for this if you want to go with this solution.
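A minimal Python sketch of the decomposition described above, assuming at least two non-negative weights with a positive sum. build_buckets and draw are hypothetical names; the weights are rescaled so that each of the N-1 buckets holds total mass 1. This scheme is closely related to Walker's alias method, and like the answer it produces draws with replacement:

```python
import random

def build_buckets(weights):
    """Split a distribution over len(weights) >= 2 symbols into len(weights) - 1
    equally likely buckets (alpha, a, b): symbol a with probability alpha, else b."""
    n = len(weights)
    total = float(sum(weights))
    if n < 2 or total <= 0 or min(weights) < 0:
        raise ValueError("need >= 2 non-negative weights with a positive sum")
    # Rescale so every bucket has capacity exactly 1 (total mass n - 1).
    remaining = {i: w * (n - 1) / total for i, w in enumerate(weights)}
    buckets = []
    while len(remaining) > 1:
        # The smallest remaining symbol always fits into one bucket...
        small = min(remaining, key=remaining.get)
        alpha = remaining.pop(small)
        # ...and the largest remaining symbol fills the rest of that bucket.
        large = max(remaining, key=remaining.get)
        remaining[large] -= 1.0 - alpha
        buckets.append((min(max(alpha, 0.0), 1.0), small, large))  # clamp float noise
    return buckets

def draw(buckets, rng=random):
    """One weighted draw: pick a bucket uniformly, then flip its biased coin."""
    alpha, a, b = rng.choice(buckets)
    return a if rng.random() < alpha else b

buckets = build_buckets([1, 10, 2])
print([draw(buckets) for _ in range(10)])
```

Building the buckets here is O(N^2) because of the repeated min/max scans; a heap or the usual small/large work lists would bring it down, but each draw is already O(1).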



Answer 3:

If you construct the right data structure for random.sample() to operate on, you don't need to define a new function at all. Just use random.sample().

Here, __getitem__() is O(n), where n is the number of distinct weighted items. But it is compact in memory, requiring only that the (weight, value) pairs be stored. I've used a similar class in practice, and it was plenty fast for my purposes. Note that this implementation assumes integer weights.

from collections.abc import Sequence

class SparseDistribution(Sequence):
    # random.sample() in Python 3 requires a Sequence, hence the base class
    _cached_length = None

    def __init__(self, weighted_items):
        # weighted items are (weight, value) pairs
        self._weighted_items = []
        for item in weighted_items:
            self.append(item)

    def append(self, weighted_item):
        self._weighted_items.append(weighted_item)
        # drop the instance-level cache so the class default (None) applies again
        self.__dict__.pop("_cached_length", None)

    def __len__(self):
        if self._cached_length is None:
            self._cached_length = sum(w for w, v in self._weighted_items)
        return self._cached_length

    def __getitem__(self, index):
        if index < 0 or index >= len(self):
            raise IndexError(index)
        for w, v in self._weighted_items:
            if index < w:
                return v
            index -= w  # skip past this item's block of positions
        raise Exception("Shouldn't have happened")

    def __iter__(self):
        for w, v in self._weighted_items:
            for _ in range(w):
                yield v

Then, we can use it:

import random

d = SparseDistribution([(5, "a"), (2, "b")])
d.append((3, "c"))

for num in (3, 5, 10, 11):
    try:
        print(random.sample(d, num))
    except Exception as e:
        print("{}({!r})".format(type(e).__name__, str(e)))

resulting in:

['a', 'a', 'b']
['b', 'a', 'c', 'a', 'b']
['a', 'c', 'a', 'c', 'a', 'b', 'a', 'a', 'b', 'c']
ValueError('Sample larger than population or is negative')


Answer 4:

Since I am currently mostly interested in the histogram of the results, I thought of the following solution using numpy.random.hypergeometric (which unfortunately behaves badly in the border cases ngood < 1, nbad < 1 and nsample < 1, so these cases need to be checked separately).

import numpy

def weighted_sample_histogram(frequencies, k, random=numpy.random):
    """ Given a sequence of absolute frequencies [w_0, w_1, ..., w_n-1],
    return a generator [s_0, s_1, ..., s_n-1] where the number s_i gives the
    absolute frequency of drawing the index i from an urn in which that index is
    represented by w_i balls, when drawing k balls without replacement. """
    W = sum(frequencies)
    if k > W:
        raise ValueError("Sum of absolute frequencies less than number of samples")
    for frequency in frequencies:
        if k < 1 or frequency < 1:
            yield 0
        else:
            W -= frequency  # balls belonging to all later indices
            if W < 1:
                good = k  # no other balls left, so this index takes all remaining draws
            else:
                good = random.hypergeometric(frequency, W, k)
            k -= good
            yield good

I gladly take any comments on how to improve this or why this might not be a good solution.

A python package implementing this (and other weighted random things) is now on http://github.com/Anaphory/weighted_choice.
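With NumPy 1.18 or later, the Generator API has a sampler for exactly this histogram: Generator.multivariate_hypergeometric(colors, nsample, method="marginals") draws the counts through repeated univariate hypergeometric calls, which is essentially the sequential scheme above. A minimal sketch, assuming integer frequencies; the wrapper name weighted_sample_histogram_np is hypothetical:

```python
import numpy as np

def weighted_sample_histogram_np(frequencies, k, rng=None):
    """Counts s_i of index i when drawing k balls without replacement from an urn
    holding frequencies[i] balls of index i (integer frequencies)."""
    if rng is None:
        rng = np.random.default_rng()
    # Returns an array with one count per entry of `frequencies`.
    return rng.multivariate_hypergeometric(frequencies, k, method="marginals")

print(weighted_sample_histogram_np([1, 10, 2], 5))
```

Unlike the generator above, this returns a NumPy array rather than yielding counts lazily.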



Answer 5:

random.sample() is pretty fast, so unless you have many megabytes of data to deal with, sample() should be fine.

On my machine it took 1.655 seconds to produce 1000 samples of length 100 from a population of 10000000, and 12.98 seconds to produce 100000 samples of length 100 from 10000000 elements.

from random import sample, random
from time import time

def generate(n1, n2, n3):
    # population of n1 random floats
    w = [random() for x in range(n1)]

    print(len(w))

    # n2 samples of n3 elements each, drawn without replacement
    samples = list()
    for i in range(0, n2):
        s = sample(w, n3)
        samples.append(s)

    return samples

start = time()
size_set = 10**7
num_samples = 10**5
length_sample = 100
samples = generate(size_set, num_samples, length_sample)
end = time()

allsum = 0
for row in samples:
    allsum += sum(row)

print('sum of all elements', allsum)

print('%f seconds for %i samples of length %i from %i elements'
      % ((end - start), num_samples, length_sample, size_set))