Weighted random numbers in MATLAB

Posted 2019-01-02 15:51

How can I randomly pick N numbers from a vector a, with a weight assigned to each number?

Let's say:

a = 1:3; % possible numbers
weight = [0.3 0.1 0.2]; % corresponding weights

In this case, the probability of picking 1 should be 3 times higher than the probability of picking 2.

The weights can sum to anything; they do not have to add up to 1.

4 answers
高级女魔头
Reply 2 · 2019-01-02 16:36

TL;DR

For maximum performance, if you only need a single sample, use

R = a( sum( (rand(1) >= cumsum(w./sum(w)))) + 1 );

and if you need multiple samples, use

[~, R] = histc(rand(N,1),cumsum([0;w(:)./sum(w)]));

Avoid randsample. Generating all the samples up front is roughly three orders of magnitude faster per sample than generating them one at a time.
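Note that the histc one-liner returns bin indices, which coincide with the values only because a = 1:3. A minimal sketch of the multi-sample case that returns the values of an arbitrary a, assuming non-negative weights, might look like this. The values in a and the round-off guard on the last edge are additions for illustration, not part of the original answer.

```matlab
% Sketch: draw N values from a with probabilities proportional to w,
% returning the values themselves rather than bin indices.
a = [10 20 30];                  % illustrative values, not just 1:numel(w)
w = [0.3 0.1 0.2];               % non-negative weights, any sum
N = 1e5;

edges = cumsum([0, w(:).' / sum(w)]);  % CDF breakpoints: [0 0.5 0.6667 1]
edges(end) = 1;                  % round-off guard: rand draws from the open interval (0,1)
[~, idx] = histc(rand(N, 1), edges);   % idx(k) = bin containing the k-th draw
R = a(idx);                      % map bin indices to the values in a
```

Because rand never returns exactly 0 or 1, every draw lands in one of bins 1..numel(w), so indexing into a is always valid. A zero weight creates a repeated edge, and that bin is then never selected.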


Performance metrics

Since this showed up near the top of my Google search, I wanted to add some performance metrics showing that the right solution depends very much on the value of N and on the requirements of the application, and that changing the design of the application can dramatically improve performance.

For large N, or indeed N > 1:

a = 1:3;             % possible numbers
w = [0.3 0.1 0.2];   % corresponding weights
N = 100000000;       % number of values to generate

w_normalized = w / sum(w)  % normalised weights, for indication

fprintf('randsample:\n');
tic
R = randsample(a, N, true, w);
toc
tabulate(R)

fprintf('bsxfun:\n');
tic
R = a( sum( bsxfun(@ge, rand(N,1), cumsum(w./sum(w))), 2) + 1 );
toc
tabulate(R)

fprintf('histc:\n');
tic
[~, R] = histc(rand(N,1),cumsum([0;w(:)./sum(w)]));
toc
tabulate(R)

Results:

w_normalized =

    0.5000    0.1667    0.3333

randsample:
Elapsed time is 2.976893 seconds.
  Value    Count   Percent
      1    49997864     50.00%
      2    16670394     16.67%
      3    33331742     33.33%
bsxfun:
Elapsed time is 2.712315 seconds.
  Value    Count   Percent
      1    49996820     50.00%
      2    16665005     16.67%
      3    33338175     33.34%
histc:
Elapsed time is 2.078809 seconds.
  Value    Count   Percent
      1    50004044     50.00%
      2    16665508     16.67%
      3    33330448     33.33%

In this case, histc is the fastest.

However, it may not be possible to generate all N values up front, for example because the weights are updated on each iteration. In that case each draw produces a single value (N = 1):

a = 1:3;             % possible numbers
w = [0.3 0.1 0.2];   % corresponding weights
I = 100000;          % number of values to generate

w_normalized = w / sum(w)  % normalised weights, for indication

R = zeros(I, 1);     % preallocate one slot per iteration (I, not N)

fprintf('randsample:\n');
tic
for i=1:I
    R(i) = randsample(a, 1, true, w);
end
toc
tabulate(R)

fprintf('cumsum:\n');
tic
for i=1:I
    R(i) = a( sum( (rand(1) >= cumsum(w./sum(w)))) + 1 );
end
toc
tabulate(R)

fprintf('histc:\n');
tic
for i=1:I
    [~, R(i)] = histc(rand(1),cumsum([0;w(:)./sum(w)]));
end
toc
tabulate(R)

Results:

w_normalized =

    0.5000    0.1667    0.3333

randsample:
Elapsed time is 3.526473 seconds.
  Value    Count   Percent
      1    50437     50.44%
      2    16149     16.15%
      3    33414     33.41%
cumsum:
Elapsed time is 0.473207 seconds.
  Value    Count   Percent
      1    50018     50.02%
      2    16748     16.75%
      3    33234     33.23%
histc:
Elapsed time is 1.046981 seconds.
  Value    Count   Percent
      1    50134     50.13%
      2    16684     16.68%
      3    33182     33.18%

In this case, the custom cumsum approach (based on the bsxfun version) is the fastest.

In any case, randsample certainly looks like a bad choice all round. The comparison also shows that if an algorithm can be arranged to generate all its random values up front, it will perform much better: the N = 1 runs generate three orders of magnitude fewer values (10^5 versus 10^8) in a similar execution time.
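For the N = 1 scenario, where the weights change between draws, a loop might look like the sketch below. The update rule (add 1 to the weight of the value just drawn, a Polya-urn style reinforcement) is purely a hypothetical example of "weights updated on each iteration". Instead of normalizing w every time, the sketch scales the uniform draw by the total weight, and it clamps the index in case round-off pushes the draw onto the last endpoint.

```matlab
% Sketch: one draw per iteration while the weights change between draws.
a = 1:3;
w = [0.3 0.1 0.2];
I = 1000;                        % number of iterations / draws
R = zeros(I, 1);                 % preallocate with the loop count I

for i = 1:I
    c = cumsum(w);               % unnormalized CDF; c(end) is the total weight
    u = rand * c(end);           % uniform draw on (0, total weight)
    k = min(sum(u >= c) + 1, numel(w));  % interval index, clamped against round-off
    R(i) = a(k);
    w(k) = w(k) + 1;             % hypothetical update: reinforce the drawn value
end
```

Scaling u rather than dividing w saves one vector division per iteration; either form selects each value with probability w(k)/sum(w).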

Code is available here.

后来的你喜欢了谁
Reply 3 · 2019-01-02 16:38

Amro gives a nice answer (which I upvoted), but it becomes very memory-intensive if you want to generate many numbers from a large set, because the bsxfun operation builds a huge array that is then summed. For example, suppose I have a set of 10000 values to sample from, all with different weights, and I want to generate 1000000 numbers from it.

This takes some work, since bsxfun generates a 1000000x10000 array internally, with 10^10 elements. It is a logical array, but even at one byte per element, 10 gigabytes of RAM must be allocated.
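The 10 GB figure can be checked by measuring the storage of a smaller logical array with whos and scaling it up. This is a back-of-the-envelope sketch: it ignores the N-by-1 double array from rand(N,1) and any temporaries.

```matlab
% Measure the storage of a small logical array, then scale up to the
% N-by-numel(w) comparison matrix that bsxfun(@ge, ...) would build.
L = false(1000, 1000);           % 10^6 logical elements
info = whos('L');
bytesPerElement = info.bytes / numel(L);   % 1 byte per logical element

nWeights = 10000;                % size of the set being sampled from
N = 1000000;                     % number of samples requested
estimatedBytes = bytesPerElement * N * nWeights;   % 1e10 bytes
fprintf('bsxfun comparison matrix: about %.0f GB\n', estimatedBytes / 1e9);
```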

A better solution is to use histc:

a = 1:3
w = [.3 .1 .2];
N = 10;

[~,R] = histc(rand(1,N),cumsum([0;w(:)./sum(w)]));
R = a(R)
R =
     1     1     1     2     2     1     3     1     1     1

Even for a problem of the size I described above, it is fast:

a = 1:10000;
w = rand(1,10000);
N = 1000000;

tic
[~,R] = histc(rand(1,N),cumsum([0;w(:)./sum(w)]));
R = a(R);
toc
Elapsed time is 0.120879 seconds.

Admittedly, my version takes two lines to write: the indexing has to happen on a second line because it uses the second output of histc. Also note that I've used a feature of newer MATLAB releases (R2009b and later), the tilde (~) placeholder for the first output of histc, which causes that output to be discarded immediately.
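The second line is what turns bin indices into values, so a does not have to be numeric or equal to 1:n. A minimal sketch with a cell array of made-up labels (the labels and the round-off guard on the last edge are illustrative additions):

```matlab
% Sample labels instead of numbers: histc gives bin indices, and the
% indexing step maps them onto whatever a contains.
labels = {'red', 'green', 'blue'};   % hypothetical values to sample from
w = [0.3 0.1 0.2];
N = 10;

edges = cumsum([0; w(:) ./ sum(w)]);
edges(end) = 1;                      % round-off guard (rand never returns exactly 1)
[~, idx] = histc(rand(1, N), edges); % bin index of each draw
picks = labels(idx);                 % 1-by-N cell array of labels
```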

浮光初槿花落
Reply 4 · 2019-01-02 16:43
R = randsample([1 2 3], N, true, [0.3 0.1 0.2])

randsample is included in the Statistics Toolbox.


Otherwise, you can use some kind of roulette-wheel selection. See this similar question (although it is not MATLAB-specific). Here is my one-line implementation:

a = 1:3;             %# possible numbers
w = [0.3 0.1 0.2];   %# corresponding weights
N = 10;              %# how many numbers to generate

R = a( sum( bsxfun(@ge, rand(N,1), cumsum(w./sum(w))), 2) + 1 )

Explanation:

Consider the interval [0,1]. We assign each element in the list (1:3) a sub-interval whose length is proportional to its weight; so 1 gets an interval of length 0.3/(0.3+0.1+0.2), and likewise for the others.

Now, if we generate a random number uniformly distributed over [0,1], every point in [0,1] is equally likely, so the length of each sub-interval is exactly the probability that the random number falls into it.

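This matches what the code above does: draw X ~ U[0,1] (actually N such numbers), then find which interval each one falls into, in a vectorized way. Comparing X against the right endpoints cumsum(w./sum(w)) and counting how many of them X has passed gives the interval index; adding 1 makes it 1-based.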
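To make the interval lookup concrete, here is the same computation for a single fixed value u = 0.6 instead of a random draw (the value 0.6 is arbitrary). The sub-intervals are [0, 0.5) for 1, [0.5, 0.6667) for 2 and [0.6667, 1) for 3, so 0.6 should select 2.

```matlab
% Worked example of the interval lookup for one fixed value of u.
a = 1:3;
w = [0.3 0.1 0.2];
c = cumsum(w ./ sum(w));    % right ends of the sub-intervals: [0.5 0.6667 1]
u = 0.6;                    % pretend this came from rand
hits = u >= c;              % [true false false]: u is past the first interval only
k = sum(hits) + 1;          % 2, so u falls in the second sub-interval
value = a(k);               % 2
```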


You can check the results of the two techniques above by generating a large enough sequence (here N=1000):

>> tabulate( R )
  Value    Count   Percent
      1      511     51.10%
      2      160     16.00%
      3      329     32.90%

which roughly matches the normalized weights w./sum(w) = [0.5 0.16667 0.33333].
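tabulate belongs to the Statistics Toolbox. A toolbox-free way to quantify "roughly matches" is to count the draws with accumarray and compare the empirical frequencies with w./sum(w). This sketch reuses the bsxfun one-liner; the clamp on idx is an added guard in case the last cumulative sum rounds slightly below 1.

```matlab
% Toolbox-free check of the sampler: count the draws per value with
% accumarray and compare the empirical frequencies with w./sum(w).
a = 1:3;
w = [0.3 0.1 0.2];
N = 1e6;

c = cumsum(w ./ sum(w));
idx = sum(bsxfun(@ge, rand(N, 1), c), 2) + 1;   % interval index of each draw
idx = min(idx, numel(w));                        % guard in case c(end) rounds below 1
R = a(idx);

counts = accumarray(idx, 1, [numel(w), 1]);      % number of draws per value
freq = counts / N;                               % empirical probabilities
expected = (w ./ sum(w)).';                      % [0.5; 0.1667; 0.3333]
maxAbsErr = max(abs(freq - expected));           % shrinks roughly like 1/sqrt(N)
disp([a(:), counts, freq, expected]);
```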

萌妹纸的霸气范
Reply 5 · 2019-01-02 16:47

Amro has a really nice answer for this topic. However, one might want a super-fast implementation to sample from huge PDFs whose domain contains several thousand values. In such scenarios, calling bsxfun and cumsum for every batch of samples can be costly. Motivated by gnovice's answer, it makes sense to implement the roulette-wheel algorithm with a run-length encoding scheme. I benchmarked Amro's solution against the new code:

%% Toy example: generate random numbers from an arbitrary PDF

a = 1:3;                                %# domain of PDF
w = [0.3 0.1 0.2];                      %# Probability Values (Weights)
N = 10000;                              %# Number of random generations

%Generate using roulette wheel + run length encoding
factor = 1 / min(w);                    %Compute min factor to assign 1 bin to min(PDF)
intW = int32(w * factor);               %Get replicator indexes for run length encoding
idxArr = zeros(1,sum(intW));            %Create index access array
idxArr([1 cumsum(intW(1:end-1))+1]) = 1;%Tag sample change indexes
sampTable = a(cumsum(idxArr));          %Create lookup table filled with samples
len = size(sampTable,2);

tic;
R = sampTable( uint32(randi([1 len],N,1)) );
toc;
tabulate(R);

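Some timings of the code above for PDFs with very long domains. "Without table" is Amro's bsxfun solution; for the table version only the sampling step is inside tic/toc, so building the table is not included: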

a ~ 15000, n = 10000
Without table: Elapsed time is 0.006203 seconds.
With table:    Elapsed time is 0.003308 seconds.
ByteSize(sampTable): 796.23 KB

a ~ 15000, n = 100000
Without table: Elapsed time is 0.003510 seconds.
With table:    Elapsed time is 0.002823 seconds.

a ~ 35000, n = 10000
Without table: Elapsed time is 0.226990 seconds.
With table:    Elapsed time is 0.001328 seconds.
ByteSize(sampTable): 2.79 MB

a ~ 35000, n = 100000
Without table: Elapsed time is 2.784713 seconds.
With table:    Elapsed time is 0.003452 seconds.

a ~ 35000, n = 1000000
Without table: bsxfun: out of memory
With table:    Elapsed time is 0.021093 seconds.

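The idea is to build a run-length-encoded lookup table in which each value of the PDF is replicated in proportion to its probability, so frequent values appear more often than rare ones. Sampling then reduces to drawing a uniformly distributed index into the table and returning the value stored there. Note that rounding w * factor to integers only approximates the weights unless every weight is an integer multiple of min(w).

It is memory-intensive, but with this approach it is possible to scale up to PDFs with hundreds of thousands of values, and each draw is a single table lookup, so sampling is very fast.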
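The idxArr/cumsum construction above can be expressed more directly with repelem (available since MATLAB R2015a), which replicates each value of a by its count. The sketch below also illustrates the rounding caveat: weights [5 2] ask for a 2.5:1 ratio, which one slot per min(w) cannot represent. The resolution parameter is a hypothetical knob, not part of the original code; raising it improves accuracy at the cost of a larger table.

```matlab
% Same lookup table as above, built with repelem (R2015a or later).
a = 1:3;
w = [0.3 0.1 0.2];
counts = round(w / min(w));        % replication counts, here [3 1 2]
sampTable = repelem(a, counts);    % [1 1 1 2 3 3]
N = 10000;
R = sampTable(randi(numel(sampTable), N, 1));   % uniform index -> weighted value

% Rounding bias: weights [5 2] ask for a 2.5:1 ratio.
w2 = [5 2];
coarse = round(w2 / min(w2));            % [3 1], i.e. 3:1 (round(2.5) is 3)
resolution = 2;                          % hypothetical: table slots given to min(w)
fine = round(resolution * w2 / min(w2)); % [5 2], exactly 2.5:1
```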
