Random weighted selection in Java

I want to choose a random item from a set, but the chance of choosing any item should be proportional to its associated weight.

Example inputs:

item                weight
----                ------
sword of misery         10
shield of happy          5
potion of dying          6
triple-edged sword       1

So, if I have 4 possible items, the chance of getting any one item without weights would be 1 in 4.

In this case, a user should be 10 times more likely to get the sword of misery than the triple-edged sword.
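In other words, each item's probability is its weight divided by the sum of all weights. That gives the 10-to-1 ratio between the sword of misery and the triple-edged sword:

```latex
\begin{aligned}
P(i) &= \frac{w_i}{\sum_j w_j}, \qquad \sum_j w_j = 10 + 5 + 6 + 1 = 22 \\
P(\text{sword of misery}) &= \tfrac{10}{22} \approx 0.455 \\
P(\text{shield of happy}) &= \tfrac{5}{22} \approx 0.227 \\
P(\text{potion of dying}) &= \tfrac{6}{22} \approx 0.273 \\
P(\text{triple-edged sword}) &= \tfrac{1}{22} \approx 0.045
\end{aligned}
```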

How do I make a weighted random selection in Java?


You will not find a framework for this kind of problem, as the requested functionality is nothing more than a simple function. Do something like this:

import java.util.List;

interface Item {
    double getWeight();
}

class RandomItemChooser {
    public Item chooseOnWeight(List<Item> items) {
        // Sum all weights.
        double completeWeight = 0.0;
        for (Item item : items)
            completeWeight += item.getWeight();
        // Pick a point r in [0, completeWeight).
        double r = Math.random() * completeWeight;
        // Walk the running total until it passes r.
        double countWeight = 0.0;
        for (Item item : items) {
            countWeight += item.getWeight();
            if (countWeight > r)  // strict '>' so a zero-weight item is never picked
                return item;
        }
        throw new RuntimeException("Should never be shown.");
    }
}
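The same cumulative scan can be made generic and testable. This is a sketch, not Arne's code: it assumes the weight comes from a ToDoubleFunction (so items need not implement Item), takes a java.util.Random so a seeded run is repeatable, and uses strict > so zero-weight items are never picked. WeightedScan and choose are hypothetical names.

```java
import java.util.List;
import java.util.Random;
import java.util.function.ToDoubleFunction;

// Hypothetical helper: linear cumulative-weight scan with an injected Random.
final class WeightedScan {
    static <T> T choose(List<T> items, ToDoubleFunction<? super T> weight, Random random) {
        double total = 0.0;
        for (T item : items) {
            total += weight.applyAsDouble(item);
        }
        if (!(total > 0.0)) {
            throw new IllegalArgumentException("total weight must be positive");
        }
        // r is uniform in [0, total).
        double r = random.nextDouble() * total;
        double cumulative = 0.0;
        for (T item : items) {
            cumulative += weight.applyAsDouble(item);
            if (cumulative > r) {
                return item;
            }
        }
        // Defensive fallback: with non-negative weights the loop above always returns.
        return items.get(items.size() - 1);
    }
}
```

For example, with the question's table in a Map<String, Double> called weights, WeightedScan.choose(names, weights::get, new Random(42)) returns "sword of misery" roughly 10/22 of the time over many draws.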

I would use a NavigableMap

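import java.util.NavigableMap;
import java.util.Random;
import java.util.TreeMap;

public class RandomCollection<E> {
    // Maps each running total of the weights to the element that ends there.
    private final NavigableMap<Double, E> map = new TreeMap<>();
    private final Random random;
    private double total = 0;

    public RandomCollection() {
        this(new Random());
    }

    public RandomCollection(Random random) {
        this.random = random;
    }

    public RandomCollection<E> add(double weight, E result) {
        if (weight <= 0) return this;  // ignore non-positive weights
        total += weight;
        map.put(total, result);
        return this;
    }

    public E next() {
        // value is in [0, total); the first running total above it identifies the element.
        // Throws NullPointerException if nothing has been added.
        double value = random.nextDouble() * total;
        return map.higherEntry(value).getValue();
    }
}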

Say I have a list of animals (dog, cat, horse) with probabilities of 40%, 35%, and 25%, respectively:

RandomCollection<String> rc = new RandomCollection<>()
        .add(40, "dog").add(35, "cat").add(25, "horse");

for (int i = 0; i < 10; i++) {
    System.out.println(rc.next());
}
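To see why higherEntry picks the right animal, here is a sketch of the map RandomCollection builds for this example (keys 40, 75, 100) and which element a few sample draws map to. LookupDemo, cumulative and lookup are hypothetical names used only for this illustration.

```java
import java.util.NavigableMap;
import java.util.TreeMap;

// Hypothetical illustration of the running-total keys behind RandomCollection.
final class LookupDemo {
    static NavigableMap<Double, String> cumulative() {
        NavigableMap<Double, String> map = new TreeMap<>();
        map.put(40.0, "dog");    // covers draws in [0, 40)
        map.put(75.0, "cat");    // covers draws in [40, 75)
        map.put(100.0, "horse"); // covers draws in [75, 100)
        return map;
    }

    // higherEntry returns the entry with the smallest key strictly greater than value.
    static String lookup(NavigableMap<Double, String> map, double value) {
        return map.higherEntry(value).getValue();
    }
}
```

A draw of exactly 40.0 goes to cat, because higherEntry wants a key strictly greater than the value. Since nextDouble() is in [0, 1), every draw lands in one half-open interval whose width equals that animal's weight.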

There is now a class for this in Apache Commons Math: EnumeratedDistribution.

Item selectedItem = new EnumeratedDistribution<>(itemWeights).sample();

where itemWeights is a List<Pair<Item, Double>>, built like this (assuming the Item interface from Arne's answer):

final List<Pair<Item, Double>> itemWeights = new ArrayList<>();
for (Item i : itemSet) {
    itemWeights.add(new Pair<>(i, i.getWeight()));
}

or in Java 8:

final List<Pair<Item, Double>> itemWeights = itemSet.stream()
        .map(i -> new Pair<>(i, i.getWeight()))
        .collect(Collectors.toList());

Note: Pair here needs to be org.apache.commons.math3.util.Pair, not org.apache.commons.lang3.tuple.Pair.

Use an alias method

If you're going to roll many times (as in a game), you should use an alias method.

The code below is a rather long implementation of an alias method, but that is because of the initialization part. Retrieving elements is very fast: the next and applyAsInt methods don't loop.

Usage

Set<Item> items = ... ;
ToDoubleFunction<Item> weighter = ... ;

Random random = new Random();

RandomSelector<Item> selector = RandomSelector.weighted(items, weighter);
Item drop = selector.next(random);

Implementation

This implementation:

  • uses Java 8;
  • is designed to be as fast as possible (or at least, I tried to make it so using micro-benchmarking);
  • is thread-safe (for maximum performance, keep one Random per thread, for example ThreadLocalRandom);
  • fetches elements in O(1), unlike most implementations found online or on Stack Overflow, which run in O(n) or O(log(n));
  • keeps the items independent from their weights, so an item can be assigned different weights in different contexts.

Anyway, here's the code. (Note that I maintain an up-to-date version of this class.)

import static java.util.Objects.requireNonNull;

import java.util.*;
import java.util.function.*;

public final class RandomSelector<T> {

    @SuppressWarnings("unchecked")
    public static <T> RandomSelector<T> weighted(Set<T> elements, ToDoubleFunction<? super T> weighter)
            throws IllegalArgumentException {
        requireNonNull(elements, "elements must not be null");
        requireNonNull(weighter, "weighter must not be null");
        if (elements.isEmpty()) { throw new IllegalArgumentException("elements must not be empty"); }

        // Array is faster than anything. Use that.
        int size = elements.size();
        T[] elementArray = elements.toArray((T[]) new Object[size]);

        double totalWeight = 0d;
        double[] discreteProbabilities = new double[size];

        // Retrieve the probabilities
        for (int i = 0; i < size; i++) {
            double weight = weighter.applyAsDouble(elementArray[i]);
            if (weight < 0.0d) { throw new IllegalArgumentException("weighter may not return a negative number"); }
            discreteProbabilities[i] = weight;
            totalWeight += weight;
        }
        if (totalWeight == 0.0d) { throw new IllegalArgumentException("the total weight of elements must be greater than 0"); }

        // Normalize the probabilities
        for (int i = 0; i < size; i++) {
            discreteProbabilities[i] /= totalWeight;
        }
        return new RandomSelector<>(elementArray, new RandomWeightedSelection(discreteProbabilities));
    }

    private final T[] elements;
    private final ToIntFunction<Random> selection;

    private RandomSelector(T[] elements, ToIntFunction<Random> selection) {
        this.elements = elements;
        this.selection = selection;
    }

    public T next(Random random) {
        return elements[selection.applyAsInt(random)];
    }

    private static class RandomWeightedSelection implements ToIntFunction<Random> {
        // Alias method implementation O(1)
        // using Vose's algorithm to initialize O(n)

        private final double[] probabilities;  // chance of keeping column i
        private final int[] alias;             // index returned otherwise

        // The argument holds the normalized probabilities and is used as scratch space.
        RandomWeightedSelection(double[] probabilities) {
            int size = probabilities.length;
            this.probabilities = new double[size];
            this.alias = new int[size];

            double average = 1.0d / size;
            int[] small = new int[size];
            int smallSize = 0;
            int[] large = new int[size];
            int largeSize = 0;

            // Describe a column as either small (below average) or large (above average).
            for (int i = 0; i < size; i++) {
                if (probabilities[i] < average) {
                    small[smallSize++] = i;
                } else {
                    large[largeSize++] = i;
                }
            }

            // For each column, saturate a small probability to average with a large probability.
            while (largeSize != 0 && smallSize != 0) {
                int less = small[--smallSize];
                int more = large[--largeSize];
                this.probabilities[less] = probabilities[less] * size;
                this.alias[less] = more;
                probabilities[more] += probabilities[less] - average;
                if (probabilities[more] < average) {
                    small[smallSize++] = more;
                } else {
                    large[largeSize++] = more;
                }
            }

            // Flush unused columns.
            while (smallSize != 0) {
                this.probabilities[small[--smallSize]] = 1.0d;
            }
            while (largeSize != 0) {
                this.probabilities[large[--largeSize]] = 1.0d;
            }
        }

        @Override public int applyAsInt(Random random) {
            // Call random once to decide which column will be used.
            int column = random.nextInt(probabilities.length);

            // Call random a second time to decide which will be used: the column or the alias.
            if (random.nextDouble() < probabilities[column]) {
                return column;
            } else {
                return alias[column];
            }
        }
    }
}
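One way to check an alias table is to compute, exactly, the probability it assigns to each index. The sketch below is an independent re-implementation of Vose's method, not the class above: it uses ArrayDeque stacks instead of int arrays and scales column heights so the average is 1 instead of 1/n. AliasCheck, build, impliedProbabilities and sample are hypothetical names.

```java
import java.util.ArrayDeque;
import java.util.Random;

// Hypothetical sketch: Vose's alias table plus an exact check of what it samples.
final class AliasCheck {
    private final double[] prob; // chance of keeping column i
    private final int[] alias;   // index returned when column i is not kept

    private AliasCheck(double[] prob, int[] alias) {
        this.prob = prob;
        this.alias = alias;
    }

    // weights must be non-negative with a positive sum.
    static AliasCheck build(double[] weights) {
        int n = weights.length;
        double total = 0.0;
        for (double w : weights) total += w;

        // Scale so the average column height is exactly 1.
        double[] scaled = new double[n];
        for (int i = 0; i < n; i++) scaled[i] = weights[i] * n / total;

        double[] prob = new double[n];
        int[] alias = new int[n];
        ArrayDeque<Integer> small = new ArrayDeque<>();
        ArrayDeque<Integer> large = new ArrayDeque<>();
        for (int i = 0; i < n; i++) (scaled[i] < 1.0 ? small : large).push(i);

        while (!small.isEmpty() && !large.isEmpty()) {
            int less = small.pop();
            int more = large.pop();
            prob[less] = scaled[less];
            alias[less] = more;
            // The large column donates what the small one is missing.
            scaled[more] -= 1.0 - scaled[less];
            (scaled[more] < 1.0 ? small : large).push(more);
        }
        // Leftovers are full columns (up to rounding error).
        while (!small.isEmpty()) prob[small.pop()] = 1.0;
        while (!large.isEmpty()) prob[large.pop()] = 1.0;
        return new AliasCheck(prob, alias);
    }

    // Each column is picked with chance 1/n; it yields i with chance prob[i], else alias[i].
    double[] impliedProbabilities() {
        int n = prob.length;
        double[] p = new double[n];
        for (int i = 0; i < n; i++) {
            p[i] += prob[i] / n;
            p[alias[i]] += (1.0 - prob[i]) / n;
        }
        return p;
    }

    // O(1) sampling: one uniform column, then one biased coin.
    int sample(Random random) {
        int column = random.nextInt(prob.length);
        return random.nextDouble() < prob[column] ? column : alias[column];
    }
}
```

With the question's weights {10, 5, 6, 1}, impliedProbabilities() comes back as 10/22, 5/22, 6/22 and 1/22 up to rounding error, which is what makes the loop-free sample method correct.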
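
import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.concurrent.ThreadLocalRandom;

public class RandomCollection<E> {
    private final NavigableMap<Double, E> map = new TreeMap<>();
    private double total = 0;

    public void add(double weight, E result) {
        // Skip non-positive weights and elements that were already added.
        // Note: containsValue is a linear scan, so each add is O(n).
        if (weight <= 0 || map.containsValue(result))
            return;
        total += weight;
        map.put(total, result);
    }

    public E next() {
        double value = ThreadLocalRandom.current().nextDouble() * total;
        return map.ceilingEntry(value).getValue();
    }
}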


There is a straightforward algorithm for picking an item at random, where items have individual weights:

  1. calculate the sum of all the weights

  2. pick a random number that is 0 or greater and is less than the sum of the weights

  3. go through the items one at a time, subtracting their weight from your random number until you get the item where the random number is less than that item's weight
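The three steps translate almost directly into code. This is a minimal sketch assuming non-negative integer weights with a positive sum; WeightedSubtract, indexFor and pick are hypothetical names. Splitting out indexFor makes step 3 testable without any randomness.

```java
import java.util.Random;

// Hypothetical sketch of the sum / draw / subtract algorithm for integer weights.
final class WeightedSubtract {
    // Step 3 on its own: given r in [0, sum), return the index it falls into.
    static int indexFor(int[] weights, int r) {
        for (int i = 0; i < weights.length; i++) {
            if (r < weights[i]) {
                return i;
            }
            r -= weights[i];
        }
        throw new IllegalArgumentException("r must be less than the sum of the weights");
    }

    static int pick(int[] weights, Random random) {
        int sum = 0;                  // step 1: total weight
        for (int w : weights) {
            sum += w;
        }
        int r = random.nextInt(sum);  // step 2: 0 <= r < sum (throws if sum <= 0)
        return indexFor(weights, r);  // step 3: subtract until r < weight
    }
}
```

With the question's weights {10, 5, 6, 1}, indexFor maps r = 0..9 to index 0, 10..14 to 1, 15..20 to 2, and 21 to 3, so each index owns exactly as many values of r as its weight.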

A simple (perhaps naive) but, I believe, straightforward method:

/**
 * Draws an integer from a given range (excluding the upper limit).
 * <p>
 * Similar to Python's random.randrange(min, max).
 *
 * @param min the smallest value that can be drawn.
 * @param max the upper limit (excluded).
 * @return The value drawn.
 */
public static int randomInt(int min, int max) {
    return min + (int) (Math.random() * (max - min));
}


/**
 * Tests whether all inner vectors of a given matrix have the expected length.
 *
 * @param matrix the matrix whose vectors will be measured.
 * @param expectedLength the length each vector should have.
 * @return false if at least one vector has a different length.
 */
public static boolean haveAllVectorsEqualLength(int[][] matrix, int expectedLength) {
    for (int[] vector : matrix) {
        if (vector.length != expectedLength) { return false; }
    }
    return true;
}


/**
 * Draws one of the given integers by weighted values.
 *
 * @param ticketBlock matrix of {value, weight} pairs for the drawing. All its
 * vectors should have length two. The weights, instead of percentages, should be
 * integers reflecting how rare each value should be drawn, the rarest
 * receiving the smallest weight.
 * @return The value drawn.
 */
public static int weightedRandomInt(int[][] ticketBlock) {
    if (!haveAllVectorsEqualLength(ticketBlock, 2)) {
        throw new IllegalArgumentException("The given matrix has at least one vector whose length is not two.");
    }
    // Need to test for duplicates or null values in ticketBlock!

    // Raffle urn building: one slot per unit of weight.
    int raffleUrnSize = 0, urnIndex = 0, blockIndex = 0, repetitionCount = 0;
    for (int[] ticket : ticketBlock) { raffleUrnSize += ticket[1]; }
    int[] raffleUrn = new int[raffleUrnSize];

    // Raffle urn filling: copy each value into the urn as many times as its weight.
    // A plain while loop (not do-while) so that a zero weight adds no copies.
    while (urnIndex < raffleUrn.length) {
        while (repetitionCount < ticketBlock[blockIndex][1]) {
            raffleUrn[urnIndex] = ticketBlock[blockIndex][0];
            urnIndex++; repetitionCount++;
        }
        repetitionCount = 0; blockIndex++;
    }

    return raffleUrn[randomInt(0, raffleUrn.length)];
}