How does this list comprehension over the inits of

Posted 2020-08-24 16:24

Question:

In the #haskell IRC channel someone asked:

Is there a succinct way to define a list where the nth entry is the sum of the squares of all entries before?

I thought this sounded like a fun puzzle, and defining infinite lists recursively is one of those things I really need to practise. So I fired up GHCi and started playing around with recursive definitions. Eventually, I managed to get to

λ> let xs = 1 : [sum (map (^2) ys) | ys <- inits xs, not (null ys)]

which seems to produce the correct results:

λ> take 9 xs
[1,1,2,6,42,1806,3263442,10650056950806,113423713055421844361000442]

Unfortunately, I have no idea how the code I wrote works. Is it possible to explain what happens when this code is executed in a way that an intermediate Haskell user would understand?

Answer 1:

It comes down to lazy evaluation. Let's go with augustss's definition, since it is slightly simpler, but call the list big instead of xs, since xs is commonly used as a variable name in utility functions such as !! and tail.

Haskell evaluates an expression only when its value is actually needed. Until then, a stub (usually called a thunk) stands in its place: basically a pointer to a closure that can calculate the value on demand.
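To see such a stub in action, here is a minimal sketch. The names partial and firstOf are made up for this illustration. The tail of the list is undefined, which would crash if it were ever evaluated, but since only the head is demanded, it never is.

```haskell
-- Hypothetical names for illustration: `partial` and `firstOf`.

-- The head is known; the tail is a stub that would crash if forced.
partial :: [Integer]
partial = 1 : undefined

-- Looks only at the first constructor of the list, never at the tail.
firstOf :: [a] -> Maybe a
firstOf (x:_) = Just x
firstOf []    = Nothing

main :: IO ()
main = print (firstOf partial)  -- prints: Just 1
```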

Let's say I want to evaluate big !! 4. The definition of !! is something like this:

[] !! _ = error "Prelude.(!!): index too large"
(x:_) !! 0 = x
(_:xs) !! n = xs !! (n-1)
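If you want to play with these three equations yourself, you can load them under a different name so they don't clash with the Prelude operator. This is only a sketch: index is a made-up name, and the real library definition may differ in details (negative indices, for one).

```haskell
-- Hypothetical name `index`: the same three equations as above,
-- renamed so they do not clash with the Prelude's (!!).
index :: [a] -> Int -> a
index []     _ = error "index: index too large"
index (x:_)  0 = x
index (_:xs) n = index xs (n - 1)

main :: IO ()
main = do
  print (index "haskell" 3)            -- prints: 'k'
  -- The undefined tail is never inspected, because the element is found first.
  print (index (1 : 2 : undefined) 1)  -- prints: 2
```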

The definition of big is

big = 1 : [sum (map (^2) ys) | ys <- tail (inits big)]
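As a standalone file, the definition needs inits from Data.List. A small sketch, with the Integer type signature added by me so the numbers can't overflow:

```haskell
import Data.List (inits)

-- Each entry after the first is the sum of the squares of all entries before it.
-- tail drops the leading [] that inits always produces.
big :: [Integer]
big = 1 : [sum (map (^2) ys) | ys <- tail (inits big)]

main :: IO ()
main = do
  print (big !! 4)     -- prints: 42
  print (take 6 big)   -- prints: [1,1,2,6,42,1806]
```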

So in evaluating the index access, the first thing that happens is that the correct function variant must be chosen. The list data type has two constructors, [] and first : rest. The call is big !! 4, and the first branch of !! just checks whether the list is []. Since the list explicitly starts with 1 : stub1, the answer is no, and the branch is skipped.

The second branch wants to know whether the first : rest form was chosen. The answer is yes, with first being 1 and rest being that big comprehension (stub1), its value irrelevant. But the second argument is not 0, so this branch is skipped as well.

The third branch also matches against first : rest, but accepts anything for the second argument, so it applies. It ignores first, binds xs to the unevaluated comprehension stub1, and n to 4. It then recursively calls itself with the first argument being the comprehension and the second 3. (Technically, that's (4-1) and isn't yet evaluated, but as a simplification we will assume it is.)

The recursive call again has to evaluate its branches. The first branch checks whether the first argument is empty. Since the argument so far is an unevaluated stub, it has to be evaluated, but only far enough to decide whether the list is empty. So let's start evaluating the comprehension:

stub1 = [sum (map (^2) ys) | ys <- tail (inits big)]

The first thing we need is ys. It is drawn from tail (inits big), one element per iteration. tail is simple (roughly):

tail [] = error "Prelude.tail: empty list"
tail (_:xs) = xs

inits is rather complex to implement, but the important thing is that it generates its result list lazily, i.e. if you give it (x:unevaluated), it will generate [] and [x] before evaluating the rest of the list. In other words, if you don't look beyond those, it won't ever evaluate the rest.
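You can check this laziness claim directly. In the sketch below, the rest of the input list is undefined, so the program would crash if inits looked past the first element before producing [] and [1].

```haskell
import Data.List (inits)

main :: IO ()
main = do
  -- Only the first two results are demanded, so the undefined tail is never touched.
  print (take 2 (inits (1 : undefined :: [Integer])))  -- prints: [[],[1]]
  -- The same laziness is what makes inits usable on infinite lists.
  print (take 3 (inits [10, 20 ..]))                   -- prints: [[],[10],[10,20]]
```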

So far, big is known to be (1 : stub1), so inits big returns [] : stub2. tail matches against this, chooses its second branch, and returns stub2. stub2 is the list of inits of big after the omnipresent empty list, and it hasn't yet been generated.

The list comprehension then tries to give ys the value of the first element of stub2, so stub2 has to be evaluated. The second element of inits big can be computed from what is already known: it's [1]. ys gets that value. At this point, then, big is known to be 1 : stub3 : stub4, where stub3 = sum (map (^2) [1]) and stub4 is the list comprehension after the first iteration.

Since big is now evaluated further, so is stub1. It is now known to be stub3 : stub4, and we can finally advance in !!. The first branch doesn't apply, since the list isn't empty. The second branch doesn't apply because 3 /= 0. The third branch applies, binding xs to stub4 and n to 3. The recursive call is stub4 !! 2.

We need to evaluate a bit of stub4. This means we enter the next iteration of the comprehension. We need the third element of inits big. Since big is by now known to be 1 : stub3 : stub4, the third element can be calculated without further evaluation as [1, stub3]. ys is bound to this value, and stub4 evaluates to stub5 : stub6, where stub5 = sum (map (^2) [1, stub3]) and stub6 is the comprehension after the first two iterations. With stub4 evaluated, we now know that big = 1 : stub3 : stub5 : stub6.

So stub4 still doesn't match the first branch of !! (nothing ever will, since we're dealing with an infinite list). 2 still doesn't match the second branch. We have another recursive call, and then another, following the same pattern as we had so far. When the index finally reaches 0, we have:

big = 1 : stub3 : stub5 : stub7 : stub9 : stub10
stub3 = sum (map (^2) [1])
stub5 = sum (map (^2) [1, stub3])
stub7 = sum (map (^2) [1, stub3, stub5])
stub9 = sum (map (^2) [1, stub3, stub5, stub7])
stub10 = whatever remains of the list comprehension

Our current call is (stub9 : stub10) !! 0, which finally matches the second branch. x is bound to stub9 and returned.

And only now, if you actually try to print or otherwise process x, are all of these stubs finally evaluated to an actual number.
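The final state can be written out by hand as ordinary definitions. This is just a sketch of the bookkeeping above, not what GHC literally builds: the stub names are the ones used in this answer, bigPrefix is a made-up name, and stub10 is replaced by undefined to show that it is never needed.

```haskell
-- The stubs from the walkthrough, written as plain definitions.
stub3, stub5, stub7, stub9 :: Integer
stub3 = sum (map (^2) [1])
stub5 = sum (map (^2) [1, stub3])
stub7 = sum (map (^2) [1, stub3, stub5])
stub9 = sum (map (^2) [1, stub3, stub5, stub7])

-- Hypothetical name: the part of big that has been unfolded so far.
-- undefined stands in for stub10, which indexing at 4 never evaluates.
bigPrefix :: [Integer]
bigPrefix = 1 : stub3 : stub5 : stub7 : stub9 : undefined

main :: IO ()
main = print (bigPrefix !! 4)  -- forces stub9 and, through it, stub3..stub7; prints: 42
```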



Answer 2:

OK, I'll try.

(I'm not sure what "intermediate" level you're looking for, so I'll explain this to myself in the hope that it's not too "sub-intermediate".)

sum (map (^2) ys) is easy: the sum of squares of a list.

The generator is also easy: ys takes on all the non-empty initial sequences of xs, i.e. (abusing notation a little) ys <- [take 1 xs, take 2 xs, take 3 xs,...].

(I'll keep the take notation in the following, as I think it's pretty clear. It's most likely not what happens internally in your shiny Haskell machine.)

The only tricky thing is combining them, as xs is the value we're defining.
This is not a huge problem, because we know the first element of xs - it's 1.
It's not much, but it's everything that's needed to get the ball rolling with take 1 xs.

Handwaving a bit more, xs is

1 : (sum (map (^2) (take 1 xs))) : (sum (map (^2) (take 2 xs))) : ...

that is (because we know that the first element is 1):

xs = 1 : (sum (map (^2) [1])) : (sum (map (^2) (take 2 xs))) : ...

xs = 1 : 1 : (sum (map (^2) (take 2 xs))) : (sum (map (^2) (take 3 xs))) : ...

where we have the second element, and we can continue:

xs = 1 : 1 : (sum (map (^2) [1,1])) : (sum (map (^2) (take 3 xs))) : ...

xs = 1 : 1 : 2 : (sum (map (^2) (take 3 xs))) : (sum (map (^2) (take 4 xs))) : ...

xs = 1 : 1 : 2 : (sum (map (^2) [1,1,2])) : (sum (map (^2) (take 4 xs))) : ...

and so on.

The reason this works at all is that every element in the list only depends on the previous elements - you can always rely on the past to tell you what's happened; the future is less reliable.
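The take notation actually works as real code too. A small sketch, assuming nothing beyond the Prelude; it is a rewrite of the original definition for illustration, not the code from the question:

```haskell
-- Entry n (counting from 0) is the sum of squares of the first n entries.
-- take n xs only ever looks at entries that come before the one being defined.
xs :: [Integer]
xs = 1 : [sum (map (^2) (take n xs)) | n <- [1 ..]]

main :: IO ()
main = print (take 6 xs)  -- prints: [1,1,2,6,42,1806]
```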



Answer 3:

Shaking your code a bit until it settles into a more comprehensible form, we get

xs = 1 : [sum (map (^2) ys) | ys <- inits xs, not (null ys)]
   = 1 : (map (sum . map (^2)) . map (`take` xs)) [1..]
   = 1 : map (sum . map (^2) . (`take` xs)) [1..]
   = 1 : scanl1 (\a b-> a+b^2) xs                 -- valid because head xs = 1 = 1^2
   = x1 : xs1
              where { x1  = 1; 
                      xs1 = scanl1 g xs 
                          = scanl g x1 xs1;   -- by scanl1 definition
                      g a x = a+x^2 }

scanl works for non-empty lists as

scanl g a xs = a : case xs of (h:t) -> scanl g (g a h) t

so xs1 = scanl g a xs1 will first put the currently known accumulating value at its output's head (xs1 = (a:_)), and only then will it read that output, so this definition is productive. We also see that h = a, so g a h = g a a = a+a^2 = a*(a+1) and we can code this stream purely iteratively, as

xs = 1 : iterate (\a -> a*(a+1)) 1
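A quick way to convince yourself that the rewriting steps preserve the list is to compare prefixes of the different forms. The names viaInits, viaScanl1 and viaIterate are made up for this sketch, and comparing the first nine entries is of course a sanity check, not a proof.

```haskell
import Data.List (inits)

-- Three of the formulations from the derivation above, under hypothetical names.
viaInits, viaScanl1, viaIterate :: [Integer]
viaInits   = 1 : [sum (map (^2) ys) | ys <- inits viaInits, not (null ys)]
viaScanl1  = 1 : scanl1 (\a b -> a + b^2) viaScanl1
viaIterate = 1 : iterate (\a -> a * (a + 1)) 1

main :: IO ()
main = do
  let n = 9
  print (take n viaIterate)
  -- True if all three definitions agree on the first n entries.
  print (take n viaInits == take n viaIterate
         && take n viaScanl1 == take n viaIterate)
```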