Calculating prime numbers in Scala: how does this

Posted 2019-01-31 14:54

Question:

So I've spent hours trying to work out exactly how this code produces prime numbers.

lazy val ps: Stream[Int] = 2 #:: Stream.from(3).filter(i =>
   ps.takeWhile{j => j * j <= i}.forall{ k => i % k > 0});

I've used a number of printlns etc., but nothing's making it clearer.

This is what I think the code does:

/**
 * [2,3] 
 * 
 * takeWhile 2*2 <= 3 
 * takeWhile 2*2 <= 4 found match
 *      (4 % [2,3] > 1) return false.
 * takeWhile 2*2 <= 5 found match
 *      (5 % [2,3] > 1) return true 
 *          Add 5 to the list
 * takeWhile 2*2 <= 6 found match
 *      (6 % [2,3,5] > 1) return false
 * takeWhile 2*2 <= 7
 *      (7 % [2,3,5] > 1) return true
 *          Add 7 to the list
 */

But if I change j*j in the takeWhile condition to be 2*2, which I assumed would work exactly the same, it causes a stack overflow error.

I'm obviously missing something fundamental here, and could really use someone explaining this to me like I was a five year old.

Any help would be greatly appreciated.

Answer 1:

I'm not sure that seeking a procedural/imperative explanation is the best way to gain understanding here. Streams come from functional programming and they're best understood from that perspective. The key aspects of the definition you've given are:

  1. It's lazy. Other than the first element in the stream, nothing is computed until you ask for it. If you never ask for the 5th prime, it will never be computed.

  2. It's recursive. The list of prime numbers is defined in terms of itself.

  3. It's infinite. Streams have the interesting property (because they're lazy) that they can represent a sequence with an infinite number of elements. Stream.from(3) is an example of this: it represents the list [3, 4, 5, ...].
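Here's a small sketch of the laziness point, in case it helps to see it happen. It's the definition from the question with one addition: a counter called inspected (my own name, not part of the original code) that records how many candidates the filter has looked at. I'm keeping Stream to match the question; since Scala 2.13 it is deprecated in favour of LazyList.

```scala
object LazinessDemo {
  // Hypothetical instrumentation: number of candidates the filter has tested.
  var inspected = 0

  // Same definition as in the question, plus the counter update.
  lazy val ps: Stream[Int] = 2 #:: Stream.from(3).filter { i =>
    inspected += 1
    ps.takeWhile(j => j * j <= i).forall(k => i % k > 0)
  }
}
```

Asking for LazinessDemo.ps.head leaves inspected at 0, because the 2 is given outright and the tail has not been touched. Forcing a few more elements with ps.take(4).toList raises the counter, and asking for the same elements a second time does not raise it further, since a Stream remembers the elements it has already computed.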

Let's see if we can understand why your definition computes the sequence of prime numbers.

The definition starts out with 2 #:: .... This just says that the first number in the sequence is 2 - simple enough so far.

The next part defines the rest of the prime numbers. We can start with all the counting numbers starting at 3 (Stream.from(3)), but we obviously need to filter a bunch of these numbers out (i.e., all the composites). So let's consider each number i. If i is not a multiple of a lesser prime number, then i is prime. That is, i is prime if, for all primes k less than i, i % k > 0. In Scala, we could express this as

nums.filter(i => ps.takeWhile(k => k < i).forall(k => i % k > 0))

However, it isn't actually necessary to check all lesser prime numbers -- we really only need to check the prime numbers whose square is less than or equal to i (this is a fact from number theory*). So we could instead write

nums.filter(i => ps.takeWhile(k => k * k <= i).forall(k => i % k > 0))

So we've derived your definition.

Now, if you happened to try the first definition (with k < i), you would have found that it didn't work. Why not? It has to do with the fact that this is a recursive definition.

Suppose we're trying to decide what comes after 2 in the sequence. The definition tells us to first determine whether 3 belongs. To do so, we consider the list of primes up to the first one greater than or equal to 3 (takeWhile(k => k < i)). The first prime is 2, which is less than 3 -- so far so good. But we don't yet know the second prime, so we need to compute it. Fine, so we need to first see whether 3 belongs ... BOOM!
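To see the difference for yourself, here is a rough sketch that puts the two definitions side by side. The names TooEager, bad and second are made up for this example. I'm assuming the standard library's Stream, where the runaway recursion shows up as a StackOverflowError (as the question reports); other lazy sequence types may fail differently.

```scala
object TooEager {
  // First attempt: test i against every smaller prime (k < i).
  lazy val bad: Stream[Int] = 2 #:: Stream.from(3).filter(i =>
    bad.takeWhile(k => k < i).forall(k => i % k > 0))

  // Derived definition: only test primes whose square is at most i.
  lazy val ps: Stream[Int] = 2 #:: Stream.from(3).filter(i =>
    ps.takeWhile(k => k * k <= i).forall(k => i % k > 0))

  // Asks a stream for its second element; reports None if evaluating it
  // recurses until the stack overflows.
  def second(s: => Stream[Int]): Option[Int] =
    try Some(s.tail.head)
    catch { case _: StackOverflowError => None }
}
```

With this, TooEager.second(TooEager.ps) gives Some(3), while TooEager.second(TooEager.bad) comes back as None: deciding whether 3 belongs needs the element after 2, which is exactly what is being computed.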

* It's pretty easy to see that if a number n is composite then the square of one of its factors must be less than or equal to n. If n is composite, then by definition n == a * b, where 1 < a <= b < n (we can guarantee a <= b just by labeling the two factors appropriately). From a <= b it follows that a^2 <= a * b, so it follows that a^2 <= n.
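If you'd rather check the footnote by brute force than take the algebra on trust, a quick sketch along these lines will do. The helper names smallestFactor and holdsUpTo are mine, and this is only a sanity check over a finite range, not a proof.

```scala
object SquareBound {
  // Smallest divisor of n strictly between 1 and n, if there is one.
  def smallestFactor(n: Int): Option[Int] =
    (2 until n).find(d => n % d == 0)

  // True if every composite n up to `limit` has a factor a with a * a <= n.
  // Primes have no such factor, so the inner forall is vacuously true for them.
  def holdsUpTo(limit: Int): Boolean =
    (4 to limit).forall { n =>
      smallestFactor(n).forall(a => a.toLong * a <= n)
    }
}
```

For instance, SquareBound.smallestFactor(9) is Some(3) and 3 * 3 <= 9, which is the boundary case that makes the <= in the takeWhile condition necessary.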



Answer 2:

Your explanation is mostly correct; you made only two mistakes:

takeWhile doesn't include the last checked element:

scala> List(1,2,3).takeWhile(_<2)
res1: List[Int] = List(1)

You assume that ps always contains only a two and a three, but because Stream is lazy it is possible to add new elements to it. In fact, each time a new prime is found it is added to ps, and in the next step takeWhile will consider this newly added element. Here it is important to remember that the tail of a Stream is computed only when it is needed, so takeWhile can't see a new element until forall has evaluated to true for it.

Keep these two things in mind and you should come up with this:

ps = [2]
i = 3
  takeWhile
    2*2 <= 3 -> false
  forall on []
    -> true
ps = [2,3]
i = 4
  takeWhile
    2*2 <= 4 -> true
    3*3 <= 4 -> false
  forall on [2]
    4%2 > 0 -> false
ps = [2,3]
i = 5
  takeWhile
    2*2 <= 5 -> true
    3*3 <= 5 -> false
  forall on [2]
    5%2 > 0 -> true
ps = [2,3,5]
i = 6
...

While these steps describe the behavior of the code, they are not fully accurate, because it is not only adding elements to the Stream that is lazy: takeWhile on it is lazy too. This means that when you call xs.takeWhile(f), the values up to the point where f becomes false are not all computed at once; they are computed when forall wants to see them (forall is the only function here that has to look at all elements before it can return true; for false it can abort earlier). Here is the computation order when laziness is considered everywhere (looking only at the candidate 9):

ps = [2,3,5,7]
i = 9
  takeWhile on 2
    2*2 <= 9 -> true
  forall on 2
    9%2 > 0 -> true
  takeWhile on 3
    3*3 <= 9 -> true
  forall on 3
    9%3 > 0 -> false
ps = [2,3,5,7]
i = 10
...

Because forall is aborted when it evaluates to false, takeWhile doesn't calculate the remaining possible elements.
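The interleaving above can be reproduced with a little logging. This is just a sketch: TraceDemo and trace are names I made up, and the log strings are chosen to match the steps written out above.

```scala
import scala.collection.mutable.ListBuffer

object TraceDemo {
  // The prime stream from the question, unchanged.
  lazy val ps: Stream[Int] = 2 #:: Stream.from(3).filter(i =>
    ps.takeWhile(j => j * j <= i).forall(k => i % k > 0))

  // Runs the primality test for one candidate i and records, in order,
  // every call made to the takeWhile and forall predicates.
  def trace(i: Int): (Boolean, List[String]) = {
    val log = ListBuffer.empty[String]
    val result = ps
      .takeWhile { p => log += s"takeWhile $p"; p * p <= i }
      .forall { p => log += s"forall $p"; i % p > 0 }
    (result, log.toList)
  }
}
```

TraceDemo.trace(9) returns false together with the log takeWhile 2, forall 2, takeWhile 3, forall 3. The two predicates alternate, and nothing beyond 3 is ever inspected.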



Answer 3:

That code is easier (for me, at least) to read with some variables renamed suggestively, as

lazy val ps: Stream[Int] = 2 #:: Stream.from(3).filter(i =>
   ps.takeWhile{p => p * p <= i}.forall{ p => i % p > 0});

This reads left-to-right quite naturally, as

primes are 2, and those numbers i from 3 up such that all of the primes p whose square does not exceed i do not divide i evenly (i.e. each leaves a non-zero remainder).

In a true recursive fashion, to understand this definition as defining the ever increasing stream of primes, we assume that it is so, and from that assumption we see that no contradiction arises, i.e. the truth of the definition holds.

The only potential problem after that, is the timing of accessing the stream ps as it is being defined. As the first step, imagine we just have another stream of primes provided to us from somewhere, magically. Then, after seeing the truth of the definition, check that the timing of the access is okay, i.e. we never try to access the areas of ps before they are defined; that would make the definition stuck, unproductive.
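The two-step reading can be tried out directly in Scala. In this sketch (all names are mine: primesGiven, magic, fromMagic) the trial division is written against a stream of primes that is handed in from outside; first it gets an independent, slow "magic" stream, and then it gets its own result.

```scala
object WizardDemo {
  // Trial division against a prime stream supplied by somebody else.
  // `known` is by-name so that a stream can be passed to its own definition.
  def primesGiven(known: => Stream[Int]): Stream[Int] =
    2 #:: Stream.from(3).filter(i =>
      known.takeWhile(p => p * p <= i).forall(p => i % p > 0))

  // Step 1: the "magic" source, computed independently by naive trial division.
  lazy val magic: Stream[Int] =
    Stream.from(2).filter(n => (2 until n).forall(d => n % d != 0))
  lazy val fromMagic: Stream[Int] = primesGiven(magic)

  // Step 2: no magic; the stream consumes itself, as in the question.
  lazy val ps: Stream[Int] = primesGiven(ps)
}
```

Both WizardDemo.fromMagic and WizardDemo.ps start 2, 3, 5, 7, 11. The self-referential version works only because each candidate i needs primes no larger than its square root, and those have already been produced by the time i is tested.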

I remember reading somewhere (don't recall where) something like the following -- a conversation between a student and a wizard,

  • student: which numbers are prime?
  • wizard: well, do you know what number is the first prime?
  • s: yes, it's 2.
  • w: okay (quickly writes down 2 on a piece of paper). And what about the next one?
  • s: well, next candidate is 3. we need to check whether it is divisible by any prime whose square does not exceed it, but I don't yet know what the primes are!
  • w: don't worry, I'll give them to you. It's a magic I know; I'm a wizard after all.
  • s: okay, so what is the first prime number?
  • w: (glances over the piece of paper) 2.
  • s: great, so its square is already greater than 3... HEY, you've cheated! .....

Here's a pseudocode[1] translation of your code, read partially right-to-left, with some variables again renamed for clarity (using p for "prime"):

ps = 2 : filter (\i-> all (\p->rem i p > 0) (takeWhile (\p->p^2 <= i) ps)) [3..]

which is also

ps = 2 : [i | i <- [3..], and [rem i p > 0 | p <- takeWhile (\p->p^2 <= i) ps]]

which is a bit more visually apparent, using list comprehensions. Here the function and checks that all entries in a list of Booleans are True (read | as "for", <- as "drawn from", the comma as "such that", and (\p-> ...) as "lambda of p").

So you see, ps is a lazy list of 2, and then of numbers i drawn from a stream [3,4,5,...] such that for all p drawn from ps such that p^2 <= i, it is true that i % p > 0. Which is actually an optimal trial division algorithm. :)
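For what it's worth, the list-comprehension reading carries over to Scala as a for-expression with a guard. This is my own transliteration, not something from the question, and it relies on Stream keeping the guard lazy.

```scala
object ComprehensionPrimes {
  // "2, then every i drawn from 3, 4, 5, ... such that no prime p
  //  with p * p <= i divides i"
  lazy val ps: Stream[Int] = 2 #:: (for {
    i <- Stream.from(3)
    if ps.takeWhile(p => p * p <= i).forall(p => i % p > 0)
  } yield i)
}
```

ComprehensionPrimes.ps.take(8).toList gives List(2, 3, 5, 7, 11, 13, 17, 19), the same as the filter-based definition.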

There's a subtlety here of course: the list ps is open-ended. We use it as it is being "fleshed out" (because it is lazy, of course). When the p's are taken from ps, it could potentially be the case that we run past its end, in which case we'd have a non-terminating calculation on our hands (a "black hole"). It just so happens :) (and it needs to be, and can be, proved mathematically) that this is impossible with the above definition. And 2 is put into ps unconditionally, so there's something in it to begin with.

But if we try to "simplify",

bad = 2 : [i | i <- [3..], and [rem i p > 0 | p <- takeWhile (\p->p < i) bad]]

it stops working after producing just one number, 2: when considering 3 as the candidate, takeWhile (\p->p < 3) bad demands the next number in bad after 2, but there aren't yet any more numbers there. It "jumps ahead of itself".

This is "fixed" with

bad = 2 : [i | i <- [3..], and [rem i p > 0 | p <- [2..(i-1)] ]]

but that is a much much slower trial division algorithm, very far from the optimal one.
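To get a feel for how much slower, one can count the remainder operations each variant performs. A rough Scala sketch follows; DivisionCount, optimal and naive are hypothetical names, and counting divisions is only a crude stand-in for running time.

```scala
object DivisionCount {
  // First n primes via the definition that stops at p * p <= i,
  // together with the number of remainder operations it performed.
  def optimal(n: Int): (List[Int], Long) = {
    var divisions = 0L
    lazy val ps: Stream[Int] = 2 #:: Stream.from(3).filter(i =>
      ps.takeWhile(p => p * p <= i).forall { p => divisions += 1; i % p > 0 })
    val primes = ps.take(n).toList
    (primes, divisions)
  }

  // Same, but testing every d in 2 .. i-1, as in the "fixed" version.
  def naive(n: Int): (List[Int], Long) = {
    var divisions = 0L
    val ps: Stream[Int] = 2 #:: Stream.from(3).filter(i =>
      (2 until i).forall { d => divisions += 1; i % d > 0 })
    val primes = ps.take(n).toList
    (primes, divisions)
  }
}
```

Both functions return the same primes, but the division count for naive grows much faster than the one for optimal as n increases, mostly because every prime i has to be tested against all of 2 .. i-1 before it is accepted.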

--

[1] Haskell actually; it's just easier for me that way :)