Weighted randomness in Java [duplicate]

Posted 2020-01-24 07:52

Question:

In Java, given n Items, each with weight w, how does one choose a random Item from the collection with a chance equal to w?

Assume each weight is a double from 0.0 to 1.0, and that the weights in the collection sum to 1. Item.getWeight() returns the Item's weight.

Answer 1:

Item[] items = ...;

// Compute the total weight of all items together
double totalWeight = 0.0d;
for (Item i : items)
{
    totalWeight += i.getWeight();
}
// Now choose a random item
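// Default to the last item in case floating-point rounding leaves random slightly above 0
int randomIndex = items.length - 1;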
double random = Math.random() * totalWeight;
for (int i = 0; i < items.length; ++i)
{
    random -= items[i].getWeight();
    if (random <= 0.0d)
    {
        randomIndex = i;
        break;
    }
}
Item myRandomItem = items[randomIndex];


Answer 2:

One elegant way is to draw, for each element, a sample from an exponential distribution (http://en.wikipedia.org/wiki/Exponential_distribution) whose rate (lambda) is the element's weight. You then simply select the element with the smallest sampled value.

In Java this looks like this:

public static <E> E getWeightedRandom(Map<E, Double> weights, Random random) {
    E result = null;
    double bestValue = Double.MAX_VALUE;

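    for (Map.Entry<E, Double> entry : weights.entrySet()) {
        // 1.0 - nextDouble() lies in (0, 1], so the logarithm is never -Infinity
        double value = -Math.log(1.0 - random.nextDouble()) / entry.getValue();

        if (value < bestValue) {
            bestValue = value;
            result = entry.getKey();
        }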
    }

    return result;
}

I am not sure whether this is more efficient than the other approaches, but if execution time is not an issue, it is a nice-looking solution.
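A short sketch of why picking the smallest sample gives the right probabilities. The notation is introduced here and is not part of the original answer: X_i is the value sampled for element i and w_i is its weight.

```latex
% -ln(U)/w_i with U uniform on (0,1) is exponentially distributed with rate w_i.
% The samples X_1, ..., X_n are assumed independent.
\[
P\bigl(X_i = \min_j X_j\bigr)
  = \int_0^\infty w_i e^{-w_i x} \prod_{j \ne i} e^{-w_j x}\,dx
  = \int_0^\infty w_i e^{-\left(\sum_j w_j\right) x}\,dx
  = \frac{w_i}{\sum_j w_j}
\]
```

So each element wins with probability proportional to its weight, and the weights do not need to sum to 1.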

And this is the same idea using Java 8 and Streams:

public static <E> E getWeightedRandomJava8(Stream<Entry<E, Double>> weights, Random random) {
    return weights
        .map(e -> new SimpleEntry<E,Double>(e.getKey(),-Math.log(random.nextDouble()) / e.getValue()))
        .min((e0,e1)-> e0.getValue().compareTo(e1.getValue()))
        .orElseThrow(IllegalArgumentException::new).getKey();
}

For instance, you can obtain the input stream of weights from a map by calling .entrySet().stream() on it.
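A minimal usage sketch of the stream version, assuming a Map<String, Double> of weights. The class name StreamPickDemo and the helper tally are hypothetical names made up for this illustration; the method itself is repeated from the answer so the example compiles on its own.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.stream.Stream;

class StreamPickDemo {
    // Same method as in the answer, repeated here for self-containment.
    static <E> E getWeightedRandomJava8(Stream<Entry<E, Double>> weights, Random random) {
        return weights
            .map(e -> new SimpleEntry<E, Double>(e.getKey(), -Math.log(random.nextDouble()) / e.getValue()))
            .min((e0, e1) -> e0.getValue().compareTo(e1.getValue()))
            .orElseThrow(IllegalArgumentException::new).getKey();
    }

    // Hypothetical helper: draws `draws` times and counts how often each key comes up.
    static Map<String, Integer> tally(Map<String, Double> weights, int draws, Random random) {
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (int i = 0; i < draws; i++) {
            // A stream can only be consumed once, so a fresh one is created per draw.
            String pick = getWeightedRandomJava8(weights.entrySet().stream(), random);
            counts.merge(pick, 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        Map<String, Double> weights = new LinkedHashMap<>();
        weights.put("a", 1.0);
        weights.put("b", 3.0);
        // "b" should come up roughly three times as often as "a".
        System.out.println(tally(weights, 10_000, new Random()));
    }
}
```

Note that the stream has to be recreated for every draw, which is why the helper takes the map rather than a stream.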



Answer 3:

TreeMap already does all the work for you.

Create a TreeMap. Compute the weights with your method of choice. Insert each item under the running total of the weights added so far, starting at 0.0, and then add that item's weight to the running total.

For example (Scala):

var count = 0.0
for (obj <- MyObjectList) { // Just any iterator over all objects
  map.put(count, obj)      // key = running total before this object's weight
  count += obj.weight
}

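Then you just have to generate a number in the valid range: rand = new Random(); num = rand.nextDouble() * count.

map.to(num).last._2             // Scala (rangeTo instead of to in 2.13+)
map.floorEntry(num).getValue()  // Java (floorKey would return the key, not the item)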

will give you the random weighted item.
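A minimal Java sketch of the TreeMap approach described above. The class name TreeMapPicker and its method names are hypothetical and chosen only for this illustration.

```java
import java.util.Random;
import java.util.TreeMap;

class TreeMapPicker<T> {
    private final TreeMap<Double, T> map = new TreeMap<>();
    private double total = 0.0;

    // Each item is stored under the running total *before* its own weight is added,
    // so an item with key k and weight w owns the half-open interval [k, k + w).
    void add(T item, double weight) {
        if (weight <= 0) {
            return; // items without positive weight can never be chosen
        }
        map.put(total, item);
        total += weight;
    }

    // Deterministic lookup for a number in [0, total); exposed to make the mapping visible.
    // Assumes at least one item was added, otherwise floorEntry returns null.
    T lookup(double num) {
        return map.floorEntry(num).getValue();
    }

    // Draws one item with probability weight / total.
    T pick(Random rand) {
        return lookup(rand.nextDouble() * total);
    }
}
```

Both insertion and lookup are O(log n), so this pays off when many items are drawn from the same set of weights.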

For a small number of buckets, a lookup table is also possible: create an array of, say, 100,000 ints and fill it with bucket numbers so that each bucket occupies a share of the slots proportional to its weight. Then generate a random integer between 0 and 99,999 and you immediately get the bucket number back from that slot.
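The lookup-table variant could be sketched as follows. LookupTablePicker and its members are hypothetical names for this illustration, and the rounding scheme is one possible choice, so bucket shares are only as exact as the table size allows.

```java
import java.util.Random;

class LookupTablePicker {
    private final int[] table; // table[slot] = bucket number owning that slot

    // weights[b] is the relative weight of bucket b; slots is the table size.
    LookupTablePicker(double[] weights, int slots) {
        double total = 0.0;
        for (double w : weights) {
            total += w;
        }
        table = new int[slots];
        double cumulative = 0.0;
        int next = 0;
        for (int b = 0; b < weights.length; b++) {
            cumulative += weights[b];
            // The last bucket takes all remaining slots so the table is always full.
            int end = (b == weights.length - 1)
                    ? slots
                    : (int) Math.min(slots, Math.round(cumulative / total * slots));
            while (next < end) {
                table[next++] = b;
            }
        }
    }

    // Deterministic view of the table, exposed to make the layout visible.
    int bucketAt(int slot) {
        return table[slot];
    }

    // O(1) selection: one random index, one array read.
    int pick(Random rand) {
        return table[rand.nextInt(table.length)];
    }
}
```

Selection is constant time, at the cost of memory proportional to the table size and a resolution limit of one slot.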



Answer 4:

If selection speed at runtime matters, spending a little more time on setup is probably best. Here is one possible solution. It has more code but guarantees O(log n) selection.

WeightedItemSelector implements selection of a random object from a collection of weighted objects. Selection is guaranteed to run in O(log n) time.

import java.util.List;
import java.util.Random;
import java.util.TreeMap;

public class WeightedItemSelector<T> {
    private final Random rnd = new Random();
    private final TreeMap<Object, Range<T>> ranges = new TreeMap<Object, Range<T>>();
    private int rangeSize; // Lowest integer higher than the top of the highest range.

    public WeightedItemSelector(List<WeightedItem<T>> weightedItems) {
        int bottom = 0; // Increments by size of non zero range added as ranges grows.

        for(WeightedItem<T> wi : weightedItems) {
            int weight = wi.getWeight();
            if(weight > 0) {
                int top = bottom + weight - 1;
                Range<T> r = new Range<T>(bottom, top, wi);
                if(ranges.containsKey(r)) {
                    Range<T> other = ranges.get(r);
                    throw new IllegalArgumentException(String.format("Range %s conflicts with range %s", r, other));
                }
                ranges.put(r, r);
                bottom = top + 1;
            }
        }
        rangeSize = bottom; 
    }

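    public WeightedItem<T> select() {
        // nextInt throws IllegalArgumentException if no item had a positive weight
        int key = rnd.nextInt(rangeSize);
        // TreeMap.get calls compareTo on the search key, so an Integer key would fail
        // with a ClassCastException. Probe with a single-point Range instead.
        Range<T> r = ranges.get(new Range<T>(key, key, null));
        if(r == null)
            return null;
        return r.weightedItem;
    }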
}

Range implements range comparison so that TreeMap lookup can be used for selection.

class  Range<T> implements Comparable<Object>{
    final int bottom;
    final int top;
    final WeightedItem<T> weightedItem;
    public Range(int bottom, int top, WeightedItem<T> wi) {
        this.bottom = bottom;
        this.top = top;
        this.weightedItem = wi;
    }

    public WeightedItem<T> getWeightedItem() {
        return weightedItem;
    }

    @Override
    public int compareTo(Object arg0) {
        if(arg0 instanceof Range<?>) {
            Range<?> other = (Range<?>) arg0;
            if(this.bottom > other.top)
                return 1;
            if(this.top < other.bottom)
                return -1;
            return 0; // overlapping ranges are considered equal.
        } else if (arg0 instanceof Integer) {
            Integer other = (Integer) arg0;
            if(this.bottom > other.intValue())
                return 1;
            if(this.top < other.intValue())
                return -1;
            return 0;
        }
        throw new IllegalArgumentException(String.format("Cannot compare Range objects to %s objects.", arg0.getClass().getName()));
    }

    /* (non-Javadoc)
     * @see java.lang.Object#toString()
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("{\"_class\": Range {\"bottom\":\"").append(bottom).append("\", \"top\":\"").append(top)
                .append("\", \"weightedItem\":\"").append(weightedItem).append("}");
        return builder.toString();
    }
}

WeightedItem simply encapsulates an item to be selected.

public class WeightedItem<T>{
    private final int weight;
    private final T item;
    public WeightedItem(int weight, T item) {
        this.item = item;
        this.weight = weight;
    }

    public T getItem() {
        return item;
    }

    public int getWeight() {
        return weight;
    }

    /* (non-Javadoc)
     * @see java.lang.Object#toString()
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("{\"_class\": WeightedItem {\"weight\":\"").append(weight).append("\", \"item\":\"")
                .append(item).append("}");
        return builder.toString();
    }
}
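For comparison, a more compact way to get O(log n) selection with integer weights is a cumulative-sum array searched with Arrays.binarySearch. This is a different technique from the TreeMap of ranges above, shown only as a sketch; PrefixSumPicker and its members are hypothetical names, and it assumes every weight is strictly positive.

```java
import java.util.Arrays;
import java.util.Random;

class PrefixSumPicker {
    private final int[] cumulative; // cumulative[i] = weights[0] + ... + weights[i]
    private final Random rnd = new Random();

    PrefixSumPicker(int[] weights) {
        cumulative = new int[weights.length];
        int sum = 0;
        for (int i = 0; i < weights.length; i++) {
            if (weights[i] <= 0) {
                // Positive weights keep the array strictly increasing, which
                // makes the binary search result unambiguous.
                throw new IllegalArgumentException("weights must be positive");
            }
            sum += weights[i];
            cumulative[i] = sum;
        }
    }

    int total() {
        return cumulative.length == 0 ? 0 : cumulative[cumulative.length - 1];
    }

    // Maps r in [0, total) to the first index whose cumulative weight exceeds r.
    int indexFor(int r) {
        int pos = Arrays.binarySearch(cumulative, r + 1);
        // A negative result encodes the insertion point as -(insertionPoint) - 1.
        return pos >= 0 ? pos : -pos - 1;
    }

    // Draws an index with probability weights[index] / total.
    int pick() {
        return indexFor(rnd.nextInt(total()));
    }
}
```

With weights 2, 3 and 5, the values 0 and 1 map to index 0, the values 2 to 4 map to index 1, and 5 to 9 map to index 2.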


Answer 5:

  1. Give some arbitrary ordering to items... (i1, i2, ..., in)... with weights w1, w2, ..., wn.
  2. Choose a random number between 0 and 1 (with sufficient granularity, by using any randomization function and appropriate scaling). Call this r0.
  3. Let j = 1
  4. Subtract wj from your r(j-1) to get rj. If rj <= 0, then you select item ij. Otherwise, increment j and repeat.
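The steps above can be sketched in Java as follows. SubtractiveSelect and select are hypothetical names, and r0 is passed in as a parameter so the procedure can be traced with fixed numbers; it returns a zero-based index rather than the one-based j used in the steps.

```java
class SubtractiveSelect {
    // weights holds w1..wn in the chosen order; r0 is the random number from step 2.
    static int select(double[] weights, double r0) {
        double r = r0;
        for (int j = 0; j < weights.length; j++) {
            r -= weights[j]; // step 4: r_j = r_(j-1) - w_j
            if (r <= 0) {
                return j;    // item j+1 in the one-based numbering of the steps
            }
        }
        // Only reached if the weights sum to less than r0, for example through
        // floating-point rounding; fall back to the last item (-1 for an empty array).
        return weights.length - 1;
    }

    public static void main(String[] args) {
        double[] weights = {0.2, 0.5, 0.3};
        // r0 = 0.6: after w1 about 0.4 remains, after w2 about -0.1, so the second item is chosen.
        System.out.println(select(weights, 0.6));
        // In real use r0 would come from a random source such as Math.random().
        System.out.println(select(weights, Math.random()));
    }
}
```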

I think I've done it like that before... but there are probably more efficient ways to do this.



Answer 6:

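Below is a randomizer that also keeps the observed selection proportions close to the weights: it picks uniformly among the items that have so far been selected less often than their allocation calls for.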

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class WeightedRandomizer
{
    private final Random randomizer;

    public WeightedRandomizer(Random randomizer)
    {
        this.randomizer = randomizer;
    }

    public IWeighable getRandomWeighable(List<IWeighable> weighables)
    {
        double totalWeight = 0.0;
        long totalSelections = 1;
        List<IWeighable> openWeighables = new ArrayList<>();

        for (IWeighable weighable : weighables)
        {
            totalWeight += weighable.getAllocation();
            totalSelections += weighable.getNumSelections();
        }

        for(IWeighable weighable : weighables)
        {
            double allocation = weighable.getAllocation() / totalWeight;
            long numSelections = weighable.getNumSelections();
            double proportion = (double) numSelections / (double) totalSelections;

            if(proportion < allocation)
            {
                openWeighables.add(weighable);
            }
        }

        IWeighable selection = openWeighables.get(this.randomizer.nextInt(openWeighables.size()));
        selection.setNumSelections(selection.getNumSelections() + 1);
        return selection;
    }
}
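The answer does not show the IWeighable type. A minimal sketch of what it might look like, inferred only from the three methods the randomizer calls on it; SimpleWeighable is a hypothetical implementation added for illustration.

```java
// Assumed shape of the interface, reconstructed from its usage in the randomizer.
interface IWeighable {
    double getAllocation();                     // relative weight of this item
    long getNumSelections();                    // how often it has been selected so far
    void setNumSelections(long numSelections);  // updated by the randomizer after each pick
}

// Hypothetical plain implementation holding a name, a weight and a counter.
class SimpleWeighable implements IWeighable {
    private final String name;
    private final double allocation;
    private long numSelections = 0;

    SimpleWeighable(String name, double allocation) {
        this.name = name;
        this.allocation = allocation;
    }

    String getName() {
        return name;
    }

    @Override
    public double getAllocation() {
        return allocation;
    }

    @Override
    public long getNumSelections() {
        return numSelections;
    }

    @Override
    public void setNumSelections(long numSelections) {
        this.numSelections = numSelections;
    }
}
```

Because the selection counts live in the items themselves, the same list of objects has to be passed to the randomizer on every call for the proportions to be tracked.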


Answer 7:

With an Item class that contains a getWeight() method (as in your question):

/**
 * Gets a random-weighted object.
 * @param items list with weighted items
 * @return a random item from items with a chance equal to its weight.
 * @assume total weight == 1
 */
public static Item getRandomWeighted(List<Item> items) {
    double value = Math.random(), weight = 0;

    for (Item item : items) {
        weight += item.getWeight();
        if (value < weight)
            return item;
    }

    return null; // Reached only if the weights sum to less than 1, e.g. through rounding
}

With a Map and a more generic method:

/**
 * Gets a random-weighted object.
 * @param balancedObjects the map with objects and their weights.
 * @return a random key-object from the map with a chance equal to its weight/totalWeight.
 * @throws IllegalArgumentException if total weight is not positive.
 */
public static <E> E getRandomWeighted(Map<E, ? extends Number> balancedObjects) throws IllegalArgumentException {
    // Java 8; summing as doubles, since mapToInt would truncate fractional weights such as 0.5 to 0
    double totalWeight = balancedObjects.values().stream().mapToDouble(Number::doubleValue).sum();

    if (totalWeight <= 0)
        throw new IllegalArgumentException("Total weight must be positive.");

    double value = Math.random()*totalWeight, weight = 0;

    for (Entry<E, ? extends Number> e : balancedObjects.entrySet()) {
        weight += e.getValue().doubleValue();
        if (value < weight)
            return e.getKey();
    }

    return null; // Never will reach this point
}