Random weighted selection in Java

Posted 2019-01-03 02:11

I want to choose a random item from a set, but the chance of choosing any item should be proportional to its associated weight.

Example inputs:

item                weight
----                ------
sword of misery         10
shield of happy          5
potion of dying          6
triple-edged sword       1

So, if I have 4 possible items, the chance of getting any one item without weights would be 1 in 4.

In this case, a user should be 10 times more likely to get the sword of misery than the triple-edged sword.

How do I make a weighted random selection in Java?
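
To make the target distribution concrete: each item's chance should be its weight divided by the sum of all weights (22 for the table above). A minimal sketch of that arithmetic, where the class and method names are made up for illustration:

```java
import java.util.LinkedHashMap;
import java.util.Map;

class WeightShares {
    /** Returns each item's selection probability: its weight divided by the sum of all weights. */
    static Map<String, Double> probabilities(Map<String, Integer> weights) {
        int total = 0;
        for (int w : weights.values()) {
            total += w;
        }
        Map<String, Double> result = new LinkedHashMap<>();
        for (Map.Entry<String, Integer> e : weights.entrySet()) {
            result.put(e.getKey(), (double) e.getValue() / total);
        }
        return result;
    }

    /** The example table from the question, in the same order. */
    static Map<String, Integer> exampleWeights() {
        Map<String, Integer> weights = new LinkedHashMap<>();
        weights.put("sword of misery", 10);
        weights.put("shield of happy", 5);
        weights.put("potion of dying", 6);
        weights.put("triple-edged sword", 1);
        return weights;
    }
}
```

With these weights the sword of misery should come up 10/22 of the time and the triple-edged sword 1/22 of the time.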

6 answers
别忘想泡老子
#2 · 2019-01-03 02:34

There is now a class for this in Apache Commons Math: EnumeratedDistribution.

Item selectedItem = new EnumeratedDistribution<>(itemWeights).sample();

where itemWeights is a List<Pair<Item,Double>>, built like this (assuming the Item interface from Arne's answer):

List<Pair<Item,Double>> itemWeights = new ArrayList<>();
for (Item i : itemSet) {
    itemWeights.add(new Pair<>(i, i.getWeight()));
}

or in Java 8:

List<Pair<Item,Double>> itemWeights = itemSet.stream()
    .map(i -> new Pair<Item, Double>(i, i.getWeight()))
    .collect(Collectors.toList());

Note: Pair here needs to be org.apache.commons.math3.util.Pair, not org.apache.commons.lang3.tuple.Pair.

放我归山
#3 · 2019-01-03 02:35

I would use a NavigableMap:

public class RandomCollection<E> {
    private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
    private final Random random;
    private double total = 0;

    public RandomCollection() {
        this(new Random());
    }

    public RandomCollection(Random random) {
        this.random = random;
    }

    public RandomCollection<E> add(double weight, E result) {
        if (weight <= 0) return this;
        total += weight;
        map.put(total, result);
        return this;
    }

    public E next() {
        double value = random.nextDouble() * total;
        return map.higherEntry(value).getValue();
    }
}

Say I have a list of animals (dog, cat, horse) with probabilities of 40%, 35% and 25% respectively:

RandomCollection<String> rc = new RandomCollection<>()
                              .add(40, "dog").add(35, "cat").add(25, "horse");

for (int i = 0; i < 10; i++) {
    System.out.println(rc.next());
} 
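
To see why this works, it can help to look at the lookup without the randomness. Each key in the map is the running total of the weights added so far, so each entry owns a half-open interval, and higherEntry finds the interval that contains the drawn value. A small sketch (the demo class and its method names are hypothetical, not part of RandomCollection):

```java
import java.util.NavigableMap;
import java.util.TreeMap;

class CumulativeLookupDemo {
    /** Builds the same kind of map as RandomCollection: key = running total of the weights. */
    static NavigableMap<Double, String> build() {
        NavigableMap<Double, String> map = new TreeMap<>();
        double total = 0;
        total += 40; map.put(total, "dog");    // key 40.0,  covers [0, 40)
        total += 35; map.put(total, "cat");    // key 75.0,  covers [40, 75)
        total += 25; map.put(total, "horse");  // key 100.0, covers [75, 100)
        return map;
    }

    /** Resolves a value in [0, total) the way next() does, with the value supplied by the caller. */
    static String pick(NavigableMap<Double, String> map, double value) {
        // higherEntry returns the entry with the smallest key strictly greater than value
        return map.higherEntry(value).getValue();
    }
}
```

Because higherEntry is strict, a value of exactly 40.0 already belongs to "cat", and since nextDouble() is always below 1.0 the value never reaches the last key, so the lookup cannot return null for a non-empty collection.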
爱情/是我丢掉的垃圾
#4 · 2019-01-03 02:43

You will not find a framework for this kind of problem, as the requested functionality is nothing more than a simple function. Do something like this:

interface Item {
    double getWeight();
}

class RandomItemChooser {
    public Item chooseOnWeight(List<Item> items) {
        double completeWeight = 0.0;
        for (Item item : items)
            completeWeight += item.getWeight();
        double r = Math.random() * completeWeight;
        double countWeight = 0.0;
        for (Item item : items) {
            countWeight += item.getWeight();
            if (countWeight >= r)
                return item;
        }
        throw new RuntimeException("Should never be shown.");
    }
}
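
One way to sanity-check a chooser like this is to draw many times with a fixed seed and compare the counts to the weights. The sketch below is a variation rather than the code above: it works on a plain double[] and takes a java.util.Random so the run is repeatable, and it uses a strict > comparison so that zero-weight entries are skipped. All names are made up for illustration.

```java
import java.util.Random;

class WeightedDrawCheck {
    /** Linear scan over the weights; returns the index of the chosen entry. */
    static int chooseIndex(double[] weights, Random random) {
        double total = 0.0;
        for (double w : weights) {
            total += w;
        }
        double r = random.nextDouble() * total;   // r lies in [0, total)
        double running = 0.0;
        for (int i = 0; i < weights.length; i++) {
            running += weights[i];
            if (running > r) {                    // strict, so a zero weight never wins
                return i;
            }
        }
        // Fallback for rounding or all-zero weights
        return weights.length - 1;
    }

    /** Draws repeatedly with a seeded generator and counts how often each index comes up. */
    static int[] countDraws(double[] weights, int draws, long seed) {
        Random random = new Random(seed);
        int[] counts = new int[weights.length];
        for (int i = 0; i < draws; i++) {
            counts[chooseIndex(weights, random)]++;
        }
        return counts;
    }
}
```

For the weights 10, 5, 6, 1 from the question and 220,000 draws, the counts should land near 100,000, 50,000, 60,000 and 10,000.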
一夜七次
#5 · 2019-01-03 02:45

If you need to remove elements after choosing them, you can use another solution. Add all the elements to a LinkedList, adding each element as many times as its weight (so the weights must be integers), then use Collections.shuffle(), which, according to its Javadoc:

Randomly permutes the specified list using a default source of randomness. All permutations occur with approximately equal likelihood.

Finally, get and remove elements using pop() or removeFirst().

Map<String, Integer> map = new HashMap<String, Integer>() {{
    put("Five", 5);
    put("Four", 4);
    put("Three", 3);
    put("Two", 2);
    put("One", 1);
}};

LinkedList<String> list = new LinkedList<>();

for (Map.Entry<String, Integer> entry : map.entrySet()) {
    for (int i = 0; i < entry.getValue(); i++) {
        list.add(entry.getKey());
    }
}

Collections.shuffle(list);

int size = list.size();
for (int i = 0; i < size; i++) {
    System.out.println(list.pop());
}
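
Note that pop() above removes one copy of an element, not the element itself. If the weights are not integers, or a chosen element should disappear entirely, a different technique is needed. The following is a sketch of one such alternative, not the shuffle approach: a small bag that picks by linear scan and then removes the chosen entry together with its weight. The WeightedBag class and its methods are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;

class WeightedBag {
    private final List<String> items = new ArrayList<>();
    private final List<Double> weights = new ArrayList<>();
    private double total = 0.0;

    void add(String item, double weight) {
        if (weight <= 0) {
            throw new IllegalArgumentException("weight must be positive");
        }
        items.add(item);
        weights.add(weight);
        total += weight;
    }

    boolean isEmpty() {
        return items.isEmpty();
    }

    /** Picks one item with probability proportional to its weight and removes it entirely. */
    String removeRandom(Random random) {
        if (items.isEmpty()) {
            throw new NoSuchElementException("bag is empty");
        }
        double r = random.nextDouble() * total;
        int chosen = items.size() - 1;   // fallback if rounding leaves r past the running sum
        double running = 0.0;
        for (int i = 0; i < items.size(); i++) {
            running += weights.get(i);
            if (running > r) {
                chosen = i;
                break;
            }
        }
        total -= weights.remove(chosen);  // remove(int) removes by index and returns the weight
        if (items.size() == 1) {
            total = 0.0;                  // last item is about to go; reset to avoid drift
        }
        return items.remove(chosen);
    }
}
```

Each removal costs O(n), which is usually fine for small collections; unlike the shuffled list, the remaining items keep their relative weights after every draw.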
SAY GOODBYE
#6 · 2019-01-03 02:49

public class RandomCollection<E> {
  private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
  private double total = 0;

  public void add(double weight, E result) {
    if (weight <= 0 || map.containsValue(result))
      return;
    total += weight;
    map.put(total, result);
  }

  public E next() {
    double value = ThreadLocalRandom.current().nextDouble() * total;
    return map.ceilingEntry(value).getValue();
  }
}
劳资没心,怎么记你
#7 · 2019-01-03 03:01

Use an alias method

If you're gonna roll a lot of times (as in a game), you should use an alias method.

The code below is a rather long implementation of such an alias method, but that is because of the initialization part. Retrieving elements is very fast: the next and applyAsInt methods don't loop.

Usage

Set<Item> items = ... ;
ToDoubleFunction<Item> weighter = ... ;

Random random = new Random();

RandomSelector<Item> selector = RandomSelector.weighted(items, weighter);
Item drop = selector.next(random);

Implementation

This implementation:

  • uses Java 8;
  • is designed to be as fast as possible (well, at least, I tried to do so using micro-benchmarking);
  • is totally thread-safe (keep one Random in each thread for maximum performance, use ThreadLocalRandom?);
  • fetches elements in O(1), unlike what you mostly find on the internet or on StackOverflow, where naive implementations run in O(n) or O(log(n));
  • keeps the items independent of their weight, so an item can be assigned various weights in different contexts.

Anyway, here's the code. (Note that I maintain an up-to-date version of this class.)

import static java.util.Objects.requireNonNull;

import java.util.*;
import java.util.function.*;

public final class RandomSelector<T> {

  public static <T> RandomSelector<T> weighted(Set<T> elements, ToDoubleFunction<? super T> weighter)
      throws IllegalArgumentException {
    requireNonNull(elements, "elements must not be null");
    requireNonNull(weighter, "weighter must not be null");
    if (elements.isEmpty()) { throw new IllegalArgumentException("elements must not be empty"); }

    // Array is faster than anything. Use that.
    int size = elements.size();
    T[] elementArray = elements.toArray((T[]) new Object[size]);

    double totalWeight = 0d;
    double[] discreteProbabilities = new double[size];

    // Retrieve the probabilities
    for (int i = 0; i < size; i++) {
      double weight = weighter.applyAsDouble(elementArray[i]);
      if (weight < 0.0d) { throw new IllegalArgumentException("weighter may not return a negative number"); }
      discreteProbabilities[i] = weight;
      totalWeight += weight;
    }
    if (totalWeight == 0.0d) { throw new IllegalArgumentException("the total weight of elements must be greater than 0"); }

    // Normalize the probabilities
    for (int i = 0; i < size; i++) {
      discreteProbabilities[i] /= totalWeight;
    }
    return new RandomSelector<>(elementArray, new RandomWeightedSelection(discreteProbabilities));
  }

  private final T[] elements;
  private final ToIntFunction<Random> selection;

  private RandomSelector(T[] elements, ToIntFunction<Random> selection) {
    this.elements = elements;
    this.selection = selection;
  }

  public T next(Random random) {
    return elements[selection.applyAsInt(random)];
  }

  private static class RandomWeightedSelection implements ToIntFunction<Random> {
    // Alias method implementation O(1)
    // using Vose's algorithm to initialize O(n)

    private final double[] probabilities;
    private final int[] alias;

    RandomWeightedSelection(double[] probabilities) {
      int size = probabilities.length;

      double average = 1.0d / size;
      int[] small = new int[size];
      int smallSize = 0;
      int[] large = new int[size];
      int largeSize = 0;

      // Describe a column as either small (below average) or large (above average).
      for (int i = 0; i < size; i++) {
        if (probabilities[i] < average) {
          small[smallSize++] = i;
        } else {
          large[largeSize++] = i;
        }
      }

      // Build the tables in separate arrays so that the input probabilities
      // stay unscaled while they are being redistributed.
      double[] pr = new double[size];
      int[] al = new int[size];

      // For each column, saturate a small probability to average with a large probability.
      while (largeSize != 0 && smallSize != 0) {
        int less = small[--smallSize];
        int more = large[--largeSize];
        pr[less] = probabilities[less] * size;
        al[less] = more;
        probabilities[more] += probabilities[less] - average;
        if (probabilities[more] < average) {
          small[smallSize++] = more;
        } else {
          large[largeSize++] = more;
        }
      }

      // Flush unused columns.
      while (smallSize != 0) {
        pr[small[--smallSize]] = 1.0d;
      }
      while (largeSize != 0) {
        pr[large[--largeSize]] = 1.0d;
      }

      // Assign the final fields, which the lookup in applyAsInt reads.
      this.probabilities = pr;
      this.alias = al;
    }

    @Override public int applyAsInt(Random random) {
      // Call random once to decide which column will be used.
      int column = random.nextInt(probabilities.length);

      // Call random a second time to decide which will be used: the column or the alias.
      if (random.nextDouble() < probabilities[column]) {
        return column;
      } else {
        return alias[column];
      }
    }
  }
}
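
An alias table is easy to get subtly wrong, so it is worth checking that the table still encodes the original distribution. The sketch below is a stripped-down, index-only variant written for illustration (it is not the RandomSelector class above, and AliasTableSketch and impliedProbability are hypothetical names). It builds the table from raw weights and then reconstructs each index's probability from the table.

```java
import java.util.Random;

class AliasTableSketch {
    final double[] probability;  // chance of keeping the column itself
    final int[] alias;           // index returned when the column is not kept

    AliasTableSketch(double[] weights) {
        int n = weights.length;
        double total = 0.0;
        for (double w : weights) {
            total += w;
        }

        // Scale the weights so that the average column height is 1.
        double[] scaled = new double[n];
        int[] small = new int[n];
        int smallSize = 0;
        int[] large = new int[n];
        int largeSize = 0;
        for (int i = 0; i < n; i++) {
            scaled[i] = weights[i] * n / total;
            if (scaled[i] < 1.0) {
                small[smallSize++] = i;
            } else {
                large[largeSize++] = i;
            }
        }

        probability = new double[n];
        alias = new int[n];
        // Pair a short column with a tall one; the tall one donates what the short one lacks.
        while (smallSize != 0 && largeSize != 0) {
            int less = small[--smallSize];
            int more = large[--largeSize];
            probability[less] = scaled[less];
            alias[less] = more;
            scaled[more] = scaled[more] + scaled[less] - 1.0;
            if (scaled[more] < 1.0) {
                small[smallSize++] = more;
            } else {
                large[largeSize++] = more;
            }
        }
        // Whatever is left is (up to rounding) exactly average height.
        while (largeSize != 0) {
            probability[large[--largeSize]] = 1.0;
        }
        while (smallSize != 0) {
            probability[small[--smallSize]] = 1.0;
        }
    }

    /** O(1) draw: pick a column uniformly, then keep it or take its alias. */
    int sample(Random random) {
        int column = random.nextInt(probability.length);
        return random.nextDouble() < probability[column] ? column : alias[column];
    }

    /** Probability of index i implied by the table: its own column plus donations to other columns. */
    double impliedProbability(int i) {
        int n = probability.length;
        double p = probability[i];
        for (int j = 0; j < n; j++) {
            if (alias[j] == i) {
                p += 1.0 - probability[j];
            }
        }
        return p / n;
    }
}
```

For the weights 10, 5, 6, 1 the implied probabilities should come back as 10/22, 5/22, 6/22 and 1/22, up to floating-point rounding.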