How to obtain the nth random “nextInt” value?

Posted 2019-06-28 01:21

When using class java.util.Random, how can one get the value obtained from calling the method nextInt() N times, but in a much more efficient way (in O(1) specifically)?

For example, if I construct a Random object with a particular seed value, and I want to get the 100,000th "nextInt() value" (that is, the value obtained after calling the method nextInt() 100,000 times) in a fast way, could I do it?

Assume, for simplicity, version 1.7.06 of the JDK, since the exact values of some private fields in class Random may matter. On that note, I found the following fields to be relevant to the calculation of a random value:

private static final long multiplier = 0x5DEECE66DL;
private static final long addend = 0xBL;
private static final long mask = (1L << 48) - 1;

After exploring a bit about randomness, I found that random values are obtained using a Linear congruential generator. The actual method that executes the algorithm is method next(int):

protected int next(int bits) {
    long oldseed, nextseed;
    AtomicLong seed = this.seed;
    do {
        oldseed = seed.get();
        nextseed = (oldseed * multiplier + addend) & mask;
    } while (!seed.compareAndSet(oldseed, nextseed));
    return (int)(nextseed >>> (48 - bits));
}

The relevant line for the algorithm is the one that obtains the next seed value:

nextseed = (oldseed * multiplier + addend) & mask;

So, to be more specific, is there a way that I can generalize this formula to obtain the "nth nextseed" value? I'm assuming here that after having that, I can then simply obtain the nth int value by letting the variable "bits" be 32 (the method nextInt() simply calls next(32) and returns the result).
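As a sanity check on the recurrence, here is a minimal sketch that steps the 48-bit state by hand and compares the result with java.util.Random. The class LcgStepSketch and its method names are made up for illustration. The initial scramble, (seed ^ multiplier) & mask, is what the Random(long) constructor applies before the first call. This is the O(n) baseline that a closed formula would have to reproduce.

```java
import java.util.Random;

// Hypothetical helper class for illustration; not part of the JDK.
class LcgStepSketch {
    static final long MULTIPLIER = 0x5DEECE66DL;
    static final long ADDEND = 0xBL;
    static final long MASK = (1L << 48) - 1;

    // Internal state that new Random(seed) starts from.
    static long scramble(long seed) {
        return (seed ^ MULTIPLIER) & MASK;
    }

    // One application of the recurrence used inside next(int).
    static long step(long state) {
        return (state * MULTIPLIER + ADDEND) & MASK;
    }

    // Value that next(32) derives from a state: the top 32 of its 48 bits.
    static int toInt(long state) {
        return (int) (state >>> (48 - 32));
    }

    // O(n) reference: apply the step n times, then extract the int.
    static int nthNextIntBySteps(long seed, long n) {
        long state = scramble(seed);
        for (long i = 0; i < n; i++) {
            state = step(state);
        }
        return toInt(state);
    }

    public static void main(String[] args) {
        Random rand = new Random(42L);
        int third = 0;
        for (int i = 0; i < 3; i++) {
            third = rand.nextInt();
        }
        // Compares the JDK value with the hand-stepped value.
        System.out.println(third == nthNextIntBySteps(42L, 3L));
    }
}
```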

Thanks in advance

PS: Perhaps this is a question more suitable for mathexchange?

Tags: java random
2 answers
爷的心禁止访问
#2 · 2019-06-28 01:35

I've accepted Daniel Fischer's answer, as it is correct and gives the general solution. Based on that answer, here is a concrete Java example showing a basic implementation of the formula. I used class BigInteger extensively, so it may not be optimal, but I confirmed a significant speedup over the rudimentary approach of actually calling the method nextInt() N times:

import java.math.BigInteger;
import java.util.Random;


public class RandomNthNextInt {

    // copied from java.util.Random =========================
    private static final long   multiplier  = 0x5DEECE66DL;
    private static final long   addend      = 0xBL;
    private static final long   mask        = (1L << 48) - 1;


    private static long initialScramble(long seed) {

        return (seed ^ multiplier) & mask;
    }

    private static int getNextInt(long nextSeed) {

        return (int)(nextSeed >>> (48 - 32));
    }
    // ======================================================

    private static final BigInteger mod = BigInteger.valueOf(mask + 1L);
    private static final BigInteger inv = BigInteger.valueOf((multiplier - 1L) / 4L).modInverse(mod);


    /**
     * Returns the value obtained after calling the method {@link Random#nextInt()} {@code n} times from a
     * {@link Random} object initialized with the {@code seed} value.
     * <p>
     * This method does not actually create any {@code Random} instance, instead it applies a direct formula which
     * calculates the expected value in a more efficient way (close to O(log N)).
     * 
     * @param seed
     *            The initial seed value of the supposed {@code Random} object
     * @param n
     *            The index (starting at 1) of the "nextInt() value"
     * @return the nth "nextInt() value" of a {@code Random} object initialized with the given seed value
     * @throws IllegalArgumentException
     *             If {@code n} is not positive
     */
    public static int getNthNextInt(long seed, long n) {

        if (n < 1L) {
            throw new IllegalArgumentException("n must be positive");
        }

        final BigInteger seedZero = BigInteger.valueOf(initialScramble(seed));
        final BigInteger nthSeed = calculateNthSeed(seedZero, n);

        return getNextInt(nthSeed.longValue());
    }

    private static BigInteger calculateNthSeed(BigInteger seed0, long n) {

        final BigInteger largeM = calculateLargeM(n);
        final BigInteger largeMmin1div4 = largeM.subtract(BigInteger.ONE).divide(BigInteger.valueOf(4L));

        return seed0.multiply(largeM).add(largeMmin1div4.multiply(inv).multiply(BigInteger.valueOf(addend))).mod(mod);
    }

    private static BigInteger calculateLargeM(long n) {

        return BigInteger.valueOf(multiplier).modPow(BigInteger.valueOf(n), BigInteger.valueOf(1L << 50));
    }

    // =========================== Testing stuff ======================================

    public static void main(String[] args) {

        final long n = 100000L; // change this to test other values
        final long seed = 1L; // change this to test other values

        System.out.println(n + "th nextInt (formula) = " + getNthNextInt(seed, n));
        System.out.println(n + "th nextInt (slow)    = " + getNthNextIntSlow(seed, n));
    }

    private static int getNthNextIntSlow(long seed, long n) {

        if (n < 1L) {
            throw new IllegalArgumentException("n must be positive");
        }

        final Random rand = new Random(seed);
        for (long eL = 0; eL < (n - 1); eL++) {
            rand.nextInt();
        }
        return rand.nextInt();
    }
}

NOTE: The method initialScramble(long) computes the first seed value. This mirrors what class Random does when an instance is initialized with a specific seed.

叛逆
#3 · 2019-06-28 01:38

You can do it in O(log N) time. Starting with s(0), if we ignore the modulus (2^48) for the moment, we can see (using m and a as shorthand for multiplier and addend) that

s(1) = s(0) * m + a
s(2) = s(1) * m + a = s(0) * m² + (m + 1) * a
s(3) = s(2) * m + a = s(0) * m³ + (m² + m + 1) * a
...
s(N) = s(0) * m^N + (m^(N-1) + ... + m + 1) * a

Now, m^N (mod 2^48) can easily be computed in O(log N) steps by modular exponentiation by repeated squaring.
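A minimal sketch of that repeated-squaring step, using plain long arithmetic. The class and method names are hypothetical. It relies on the fact that Java long multiplication wraps modulo 2^64; since 2^48 divides 2^64, masking the wrapped product still gives the correct residue modulo 2^48.

```java
// Hypothetical helper class illustrating square-and-multiply.
class PowMod48Sketch {
    static final long MASK = (1L << 48) - 1;

    // Returns base^exp mod 2^48 for exp >= 0, using O(log exp) multiplications.
    static long powMod48(long base, long exp) {
        long result = 1L;
        long square = base & MASK; // holds base^(2^k) mod 2^48 in round k
        while (exp > 0) {
            if ((exp & 1L) == 1L) {
                // Current bit of the exponent is set: fold this power in.
                result = (result * square) & MASK;
            }
            square = (square * square) & MASK;
            exp >>>= 1;
        }
        return result;
    }

    public static void main(String[] args) {
        // m^100000 mod 2^48, printed in hex.
        System.out.println(Long.toHexString(powMod48(0x5DEECE66DL, 100000L)));
    }
}
```

For N = 100,000 the loop runs 17 times (one round per bit of the exponent) instead of 100,000 multiplications.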

The other part is a bit more complicated. Ignoring the modulus again for the moment, the geometric sum is

(m^N - 1) / (m - 1)

What makes computing this modulo 2^48 a bit nontrivial is that m - 1 is not coprime to the modulus. However, since

m = 0x5DEECE66DL

the greatest common divisor of m-1 and the modulus is 4, and (m-1)/4 has a modular inverse inv modulo 2^48. Let

c = (m^N - 1) (mod 4*2^48)

Then

(c / 4) * inv ≡ (m^N - 1) / (m - 1) (mod 2^48)
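The extra factor of 4 in the modulus is there because dividing c by 4 discards two bits: to know (m^N - 1)/4 modulo 2^48, one needs m^N - 1 modulo 2^50. The following sketch checks the congruence numerically by comparing the closed form against a brute-force sum. The class and method names are hypothetical, and BigInteger is used only to keep the formula readable.

```java
import java.math.BigInteger;

// Hypothetical helper class: checks the closed form of 1 + m + ... + m^(n-1).
class GeometricSumSketch {
    static final long M = 0x5DEECE66DL;
    static final long MASK48 = (1L << 48) - 1;
    static final BigInteger MOD48 = BigInteger.ONE.shiftLeft(48);
    static final BigInteger MOD50 = BigInteger.ONE.shiftLeft(50);
    // Inverse of (m - 1) / 4 modulo 2^48; it exists because (m - 1) / 4 is odd.
    static final BigInteger INV = BigInteger.valueOf((M - 1L) / 4L).modInverse(MOD48);

    // Brute force, O(n): Horner's rule for 1 + m + ... + m^(n-1) mod 2^48.
    static long sumBySteps(long n) {
        long sum = 0L;
        for (long i = 0; i < n; i++) {
            sum = (sum * M + 1L) & MASK48;
        }
        return sum;
    }

    // Closed form, O(log n): ((m^n - 1 mod 2^50) / 4) * inv mod 2^48.
    static long sumByFormula(long n) {
        BigInteger c = BigInteger.valueOf(M)
                .modPow(BigInteger.valueOf(n), MOD50)
                .subtract(BigInteger.ONE);
        // c is divisible by 4 because m is congruent to 1 modulo 4.
        return c.shiftRight(2).multiply(INV).mod(MOD48).longValue();
    }

    public static void main(String[] args) {
        System.out.println(sumBySteps(1000L) == sumByFormula(1000L));
    }
}
```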

So

  • compute M ≡ m^N (mod 2^50)
  • compute inv

to obtain

s(N) ≡ s(0)*M + ((M - 1)/4)*inv*a (mod 2^48)
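Putting the pieces together, the formula can be sketched without BigInteger, since wrapping long arithmetic is exact modulo any power of two up to 2^64. Everything named here (JumpAheadSketch, inverseOfOdd, powMod50, nthSeed, nthNextInt) is hypothetical. The inverse is obtained by Newton iteration, which is a swapped-in technique not mentioned above; BigInteger.modInverse would work equally well. The initial scramble and the final shift by 16 bits follow what java.util.Random does for nextInt().

```java
// Hypothetical helper class: O(log n) jump-ahead for java.util.Random's LCG.
class JumpAheadSketch {
    static final long MULTIPLIER = 0x5DEECE66DL;
    static final long ADDEND = 0xBL;
    static final long MASK48 = (1L << 48) - 1;
    static final long MASK50 = (1L << 50) - 1;
    // inv: inverse of (m - 1) / 4; valid modulo 2^64, hence also modulo 2^48.
    static final long INV = inverseOfOdd((MULTIPLIER - 1L) / 4L);

    // Inverse of an odd value modulo 2^64 via Newton iteration.
    static long inverseOfOdd(long a) {
        long x = a; // a * a is 1 modulo 8, so the 3 low bits are already right
        for (int i = 0; i < 5; i++) {
            x *= 2L - a * x; // each round doubles the number of correct low bits
        }
        return x;
    }

    // M = m^n mod 2^50 by repeated squaring.
    static long powMod50(long n) {
        long result = 1L;
        long square = MULTIPLIER;
        while (n > 0) {
            if ((n & 1L) == 1L) {
                result = (result * square) & MASK50;
            }
            square = (square * square) & MASK50;
            n >>>= 1;
        }
        return result;
    }

    // s(n) = s(0)*M + ((M - 1)/4)*inv*a (mod 2^48)
    static long nthSeed(long s0, long n) {
        long bigM = powMod50(n);
        long geometricSum = (((bigM - 1L) >>> 2) * INV) & MASK48;
        return (s0 * bigM + geometricSum * ADDEND) & MASK48;
    }

    // Value of the n-th nextInt() call (n >= 1) on new Random(seed).
    static int nthNextInt(long seed, long n) {
        if (n < 1L) {
            throw new IllegalArgumentException("n must be positive");
        }
        long s0 = (seed ^ MULTIPLIER) & MASK48; // initial scramble
        return (int) (nthSeed(s0, n) >>> 16);   // top 32 of the 48 state bits
    }

    public static void main(String[] args) {
        System.out.println(nthNextInt(1L, 100000L));
    }
}
```

A quick way to gain confidence in a sketch like this is to compare nthNextInt(seed, n) against a loop that calls nextInt() n times for a handful of seeds and values of n.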