In the #haskell IRC channel someone asked:
Is there a succinct way to define a list where the nth entry is the sum of the squares of all entries before?
I thought this sounded like a fun puzzle, and defining infinite lists recursively is one of those things I really need to practise. So I fired up GHCi and started playing around with recursive definitions. Eventually, I managed to get to
λ> let xs = 1 : [sum (map (^2) ys) | ys <- inits xs, not (null ys)]
which seems to produce the correct results:
λ> take 9 xs
[1,1,2,6,42,1806,3263442,10650056950806,113423713055421844361000442]
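The one-liner uses inits, which lives in Data.List rather than the Prelude, so it has to be imported before the definition works. Below is a minimal sketch of the same definition as a file that can be loaded into GHCi. The type signature is an addition, not part of the original one-liner: it pins the elements to Integer so the fast-growing values don't overflow.

```haskell
import Data.List (inits)

-- The n-th element is the sum of the squares of all elements before it.
-- The first element is given explicitly to get the recursion started.
xs :: [Integer]
xs = 1 : [sum (map (^2) ys) | ys <- inits xs, not (null ys)]
```

After loading it, take 9 xs should reproduce the output shown above.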
Unfortunately, I have no idea how the code I wrote works. Is it possible to explain what happens when this code is executed in a way that an intermediate Haskell user would understand?
It comes down to lazy evaluation. Let's go with augustss's definition, since it is just slightly simpler, but call it big instead of xs, since that identifier is conventionally used for list arguments in utility functions.
Haskell only evaluates code when it is actually needed. If something isn't needed yet, a stub (usually called a thunk) stands in its place: basically a pointer to a function closure that can calculate the value if it is ever needed.
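A small sketch of such a stub in action (the name lazyLength is just for illustration): length only needs the list's cons cells, not its elements, so an element that would crash if evaluated is never touched.

```haskell
-- The first element would raise an exception if it were ever evaluated.
-- length only walks the spine (the cons cells), so that stub is never forced.
lazyLength :: Int
lazyLength = length [error "this element is never evaluated", 2 + 2 :: Int]
```

Evaluating lazyLength gives 2 without ever touching the error.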
Let's say I want to evaluate big !! 4. The definition of !! is something like this:
[] !! _ = error "Prelude.(!!): index too large"
(x:_) !! 0 = x
(_:xs) !! n = xs !! (n-1)
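To experiment with the three branches without clashing with the Prelude's (!!), here is the same definition under a hypothetical name, at. Applied to an infinite list, the first branch can never match, so only the index decides when the recursion stops.

```haskell
-- Same three equations as the (!!) definition above, renamed to `at`
-- so it does not clash with the Prelude. Each equation is one "branch".
at :: [a] -> Int -> a
at []     _ = error "at: index too large"  -- branch 1: list is empty
at (x:_)  0 = x                            -- branch 2: index reached 0
at (_:xs) n = at xs (n - 1)                -- branch 3: skip one cell
```

For example, at [10, 20 ..] 3 walks past three cons cells of an infinite list and returns 40.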
The definition of big is
big = 1 : [sum (map (^2) ys) | ys <- tail (inits big)]
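A sketch putting augustss's variant next to the question's definition in one file; they differ only in how they skip the empty prefix that inits always produces first.

```haskell
import Data.List (inits)

-- augustss's variant: drop the leading [] from inits with `tail`.
big :: [Integer]
big = 1 : [sum (map (^2) ys) | ys <- tail (inits big)]

-- The question's version: filter the empty prefix out with `not (null ys)`.
xs :: [Integer]
xs = 1 : [sum (map (^2) ys) | ys <- inits xs, not (null ys)]
```

Evaluating take 9 big == take 9 xs in GHCi should give True, and big !! 4, the expression traced below, evaluates to 42.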
So in evaluating the index access, the first thing that happens is that the correct function variant must be chosen. The list data type has two constructors, [] and first : rest. The call is big !! 4, and the first branch of !! just checks whether the list is []. Since the list explicitly starts with 1 : stub1, the answer is no, and the branch is skipped.
The second branch wants to know whether the first : rest form was chosen. The answer is yes, with first being 1 and rest being that big comprehension (stub1), whose value is irrelevant here. But the second argument is not 0, so this branch is skipped as well.
The third branch also matches against first : rest, but accepts anything for the second argument, so it applies. It ignores first, binds xs to the unevaluated comprehension stub1, and binds n to 4. It then recursively calls itself with the comprehension as the first argument and 3 as the second. (Technically, the second argument is (4-1) and isn't evaluated yet, but as a simplification we will assume it is.)
The recursive call again has to evaluate its branches. The first branch checks whether the first argument is empty. Since the argument so far is an unevaluated stub, it needs to be evaluated, but only far enough to decide whether the list is empty. So let's start evaluating the comprehension:
stub1 = [sum (map (^2) ys) | ys <- tail (inits big)]
The first thing we need is ys. It's set to tail (inits big). tail is simple:
tail [] = []
tail (_:xs) = xs
inits is rather complex to implement, but the important thing is that it generates its result list lazily, i.e. if you give it (x:unevaluated), it will generate [] and [x] before evaluating the rest of the list. In other words, if you don't look beyond those, it won't ever evaluate the rest.
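This laziness is easy to check directly. In the sketch below (firstTwoInits is an illustrative name), the list's tail is undefined and would crash if anything forced it, yet the first two inits come out fine.

```haskell
import Data.List (inits)

-- The tail after the first element is `undefined`; forcing it would crash.
-- inits still produces [] and [1] without looking past the first cons cell.
firstTwoInits :: [[Int]]
firstTwoInits = take 2 (inits (1 : undefined))
```

Asking for a third element, as in take 3, would force the undefined tail and raise an exception.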
So far, big is known to be (1 : stub1), so inits returns [] : stub2. tail matches against this, chooses its second branch, and returns stub2. stub2 is the list of inits of big after the omnipresent empty list, and it hasn't been generated yet.
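To see what tail (inits big) actually hands to the comprehension, we can take a few of its elements; each prefix only reaches as far into big as has already been produced. A sketch, with prefixes as an illustrative name:

```haskell
import Data.List (inits)

big :: [Integer]
big = 1 : [sum (map (^2) ys) | ys <- tail (inits big)]

-- The successive values ys is bound to: [1], then [1,1], then [1,1,2], ...
prefixes :: [[Integer]]
prefixes = take 3 (tail (inits big))
```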
The list comprehension then tries to give ys the value of the first element of stub2, so stub2 has to be evaluated. The second result of inits can be computed from what is already known: it's [1]. ys gets that value. At this point, then, big is known to be 1 : stub3 : stub4, where stub3 = sum (map (^2) [1]) and stub4 is the list comprehension after the first iteration.
Since big is now evaluated further, so is stub1. It is now known to be stub3 : stub4, and we can finally advance in !!. The first branch doesn't apply, since the list isn't empty. The second branch doesn't apply because 3 /= 0. The third branch applies, binding xs to stub4 and n to 3. The recursive call is stub4 !! 2.
We need to evaluate a bit of stub4. This means we enter the next iteration of the comprehension. We need the third element of inits big. Since big is by now known to be 1 : stub3 : stub4, the third element can be calculated without further evaluation as [1, stub3]. ys is bound to this value, and stub4 evaluates to stub5 : stub6, where stub5 = sum (map (^2) [1, stub3]) and stub6 is the comprehension after the first two iterations. With stub4 evaluated, we now know that big = 1 : stub3 : stub5 : stub6.
So stub4 still doesn't match the first branch of !! (nothing ever will, since we're dealing with an infinite list), and the index 2 still doesn't match the second branch. We have another recursive call, and then another, following the same pattern as before. When the index finally reaches 0, we have:
big = 1 : stub3 : stub5 : stub7 : stub9 : stub10
stub3 = sum (map (^2) [1])
stub5 = sum (map (^2) [1, stub3])
stub7 = sum (map (^2) [1, stub3, stub5])
stub9 = sum (map (^2) [1, stub3, stub5, stub7])
stub10 = whatever remains of the list comprehension
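The stubs in the listing above can be written out as ordinary Haskell bindings, which makes it easy to confirm that stub9 is the value big !! 4 eventually returns. This is a hand-written sketch mirroring the trace, not something GHC produces.

```haskell
-- Each binding mirrors one stub from the trace. Nothing is computed until
-- stub9 is demanded; then each stub forces the ones it mentions.
stub3, stub5, stub7, stub9 :: Integer
stub3 = sum (map (^2) [1])
stub5 = sum (map (^2) [1, stub3])
stub7 = sum (map (^2) [1, stub3, stub5])
stub9 = sum (map (^2) [1, stub3, stub5, stub7])
```

Demanding stub9 forces stub7, which forces stub5, and so on down to stub3: the same chain of evaluations that happens when x is finally printed.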
Our current call is (stub9 : stub10) !! 0, which finally matches the second branch. x is bound to stub9 and returned.
And only now, if you actually try to print or otherwise process x, are all of these stubs finally evaluated to an actual number.
OK, I'll try.
(I'm not sure what "intermediate" level you're looking for, so I'll explain this to myself in the hope that it's not too "sub-intermediate".)
sum (map (^2) ys) is easy: the sum of squares of a list.
The generator is also easy: ys takes on all the non-empty initial sequences of xs, i.e. (abusing notation a little) ys <- [take 1 xs, take 2 xs, take 3 xs, ...].
(I'll keep the take notation in the following, as I think it's pretty clear. It's most likely not what happens internally in your shiny Haskell machine.)
The only tricky thing is combining them, as xs is the value we're defining.
This is not a huge problem, because we know the first element of xs: it's 1.
It's not much, but it's everything that's needed to get the ball rolling with take 1 xs.
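The take notation can be made literal. The sketch below (xsTake is an illustrative name) is an equivalent rewrite for reading purposes, not what inits does internally: it builds each prefix with take instead of inits but yields the same elements.

```haskell
-- After the seed 1, the element produced for n is the sum of squares
-- of the first n elements of the list itself (n = 1, 2, 3, ...).
xsTake :: [Integer]
xsTake = 1 : [sum (map (^2) (take n xsTake)) | n <- [1 ..]]
```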
Handwaving a bit more, xs is
1 : (sum (map (^2) (take 1 xs))) : (sum (map (^2) (take 2 xs))) : ...
that is (because we know that the first element is 1):
xs = 1 : (sum (map (^2) [1])) : (sum (map (^2) (take 2 xs))) : ...
xs = 1 : 1 : (sum (map (^2) (take 2 xs))) : (sum (map (^2) (take 3 xs))) : ...
where we have the second element, and we can continue:
xs = 1 : 1 : (sum (map (^2) [1,1])) : (sum (map (^2) (take 3 xs))) : ...
xs = 1 : 1 : 2 : (sum (map (^2) (take 3 xs))) : (sum (map (^2) (take 4 xs))) : ...
xs = 1 : 1 : 2 : (sum (map (^2) [1,1,2])) : (sum (map (^2) (take 4 xs))) : ...
and so on.
The reason this works at all is that every element in the list only depends on the previous elements: you can always rely on the past to tell you what's happened; the future is less reliable.
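The same dependency on the past can be made explicit without any laziness: build the list strictly, computing each new element only from the elements already produced. A sketch with illustrative names sumSq and firstN:

```haskell
-- Sum of the squares of a list.
sumSq :: [Integer] -> Integer
sumSq = sum . map (^2)

-- The first n elements, built one at a time. Each step appends the sum of
-- squares of everything produced so far (acc), never of anything later.
-- Appending with ++ is quadratic, which is fine for small n.
firstN :: Int -> [Integer]
firstN n = go [1]
  where
    go acc
      | length acc >= n = take n acc
      | otherwise       = go (acc ++ [sumSq acc])
```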
Shaking your code a bit until it settles into a more comprehensible form, we get
xs = 1 : [sum (map (^2) ys) | ys <- inits xs, not (null ys)]
= 1 : (map (sum . map (^2)) . map (`take` xs)) [1..]
= 1 : map (sum . map (^2) . (`take` xs)) [1..]
= 1 : scanl1 (\a b-> a+b^2) xs
= x1 : xs1
where { x1 = 1;
xs1 = scanl1 g xs
= scanl g x1 xs1; -- by scanl1 definition
g a x = a+x^2 }
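A sketch of the scanl1 form from the middle of that derivation, runnable as-is (xsScan is an illustrative name). Note that scanl1 emits its first input unchanged rather than squared, which only gives the right answer because the seed is 1 and 1^2 = 1.

```haskell
-- a is the running total, b is the next input element; each output adds
-- the square of the next input to the total. The list reads its own output.
xsScan :: [Integer]
xsScan = 1 : scanl1 (\a b -> a + b^2) xsScan
```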
scanl works for non-empty lists as
scanl g a xs = a : case xs of (h:t) -> scanl g (g a h) t
so xs1 = scanl g a xs1 (with a standing for the current accumulator, initially x1) will first put the currently known accumulating value at its output's head (xs1 = (a:_)), and only then will it read that output, so this definition is productive. We also see that h = a, so g a h = g a a = a+a^2 = a*(a+1), and we can code this stream purely iteratively, as
xs = 1 : iterate (\a -> a*(a+1)) 1
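A final sketch comparing the iterative stream with the question's original definition (xsIter and xsOrig are illustrative names); as the derivation predicts, the two agree element by element.

```haskell
import Data.List (inits)

-- Purely iterative form: each element after the seed is a*(a+1),
-- where a is the previous element.
xsIter :: [Integer]
xsIter = 1 : iterate (\a -> a * (a + 1)) 1

-- The question's definition, for comparison.
xsOrig :: [Integer]
xsOrig = 1 : [sum (map (^2) ys) | ys <- inits xsOrig, not (null ys)]
```

Evaluating take 9 xsIter == take 9 xsOrig in GHCi should give True.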