Consider
filterM (\x -> [True, False]) [1, 2, 3]
I just cannot understand the magic that Haskell does with this `filterM` use case. The source code for this function is listed below:
filterM :: (Monad m) => (a -> m Bool) -> [a] -> m [a]
filterM _ [] = return []
filterM p (x:xs) = do
  flg <- p x
  ys <- filterM p xs
  return (if flg then x:ys else ys)
With this use case, `p` should be the lambda function `(\x -> [True, False])`, and the first `x` should be `1`. So what does `flg <- p x` return? What exactly is the value of `flg` for each recursion?
The list monad `[]` models non-determinism: a list of type `[a]` represents a number of different possibilities for a value of type `a`.
When you see a statement like `flg <- p x` in the list monad, `flg` will take on each value of `p x` in turn, i.e. `True` and then `False` in this case. The rest of the body of `filterM` is then executed twice, once for each value of `flg`.
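A minimal sketch of that behaviour, using a hypothetical value `branches` that is not part of the original answer: the single bind `flg <- [True, False]` makes the remaining line run once with `True` and once with `False`, and the list monad collects both results.

```haskell
-- Hypothetical example: one `<-` bind in the list monad.
-- `flg` is bound to True, then to False, and the rest of the
-- block runs once for each value; the results are concatenated.
branches :: [String]
branches = do
  flg <- [True, False]
  return (if flg then "ran with True" else "ran with False")

-- branches == ["ran with True", "ran with False"]
```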
To see how this happens in more detail, you need to understand the desugaring of `do` notation and the implementation of the `(>>=)` operator for the list monad.
`do` notation gets desugared line by line into calls to the `(>>=)` operator. For example, the body of the non-empty `filterM` case turns into
p x >>= \flg -> (filterM p xs >>= \ys -> return (if flg then x:ys else ys))
This is completely mechanical: in essence it just replaces `flg <-` before the expression with `>>= \flg ->` after the expression. In reality, pattern matching makes this a little more complicated, but not much.
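The desugaring can be checked directly. The sketch below uses hypothetical names (`filterMDo`, `filterMBind`, `agrees`): it writes the `filterM` definition from the question once with `do` notation and once as the hand-desugared `>>=` chain above, then compares both with the library `filterM` from `Control.Monad`. In recent versions of base the library `filterM` has an `Applicative` constraint rather than `Monad`, but for lists it yields the results in the same order.

```haskell
import Control.Monad (filterM)

-- Hypothetical copy of filterM as written in the question, with do notation.
filterMDo :: Monad m => (a -> m Bool) -> [a] -> m [a]
filterMDo _ [] = return []
filterMDo p (x:xs) = do
  flg <- p x
  ys <- filterMDo p xs
  return (if flg then x:ys else ys)

-- The same function after mechanically replacing each `v <- e` with
-- `e >>= \v -> ...`; each lambda body extends to the end of the equation.
filterMBind :: Monad m => (a -> m Bool) -> [a] -> m [a]
filterMBind _ [] = return []
filterMBind p (x:xs) =
  p x >>= \flg ->
  filterMBind p xs >>= \ys ->
  return (if flg then x:ys else ys)

-- Both hypothetical versions agree with the library on the question's example.
agrees :: Bool
agrees = all (== filterM p [1, 2, 3]) [filterMDo p [1, 2, 3], filterMBind p [1, 2, 3]]
  where p = const [True, False] :: Int -> [Bool]

-- filterMBind (const [True, False]) [1, 2] == [[1,2],[1],[2],[]]
```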
Next is the actual implementation of `(>>=)`, which is a member of the `Monad` type class and has a different implementation for each instance. For `[]`, the type is:
(>>=) :: [a] -> (a -> [b]) -> [b]
and the implementation is something like
[] >>= f = []
(x:xs) >>= f = f x ++ (xs >>= f)
So the loop happens in the body of `(>>=)`. This is all library code; there is no compiler magic beyond the desugaring of `do` notation.
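Since the real `Monad []` instance lives in base and cannot be redefined, the sketch below copies the pattern-matching definition above into a stand-alone function with the hypothetical name `bindList`, so the recursion that drives the "loop" can be seen and tested on its own.

```haskell
-- Hypothetical stand-alone copy of the list monad's (>>=).
-- The recursion over the first list is the "loop": f is applied to each
-- element in turn and the resulting lists are appended in order.
bindList :: [a] -> (a -> [b]) -> [b]
bindList [] _ = []
bindList (x:xs) f = f x ++ bindList xs f

-- bindList [1, 2, 3] (\x -> [x, x * 10]) == [1, 10, 2, 20, 3, 30]
```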
An equivalent definition for `(>>=)` is
xs >>= f = concat (map f xs)
which may also help you see what's happening.
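Putting the two pieces together, `filterM` can be specialised by hand to lists: every `>>=` becomes `concat (map ...)` and every `return v` becomes `[v]`. This is a sketch with the hypothetical name `filterList`, not the library code.

```haskell
-- Hypothetical list-only filterM: (>>=) is replaced by concat (map ...)
-- and return by a singleton list, so no Monad class is involved.
filterList :: (a -> [Bool]) -> [a] -> [[a]]
filterList _ [] = [[]]
filterList p (x:xs) =
  concat (map (\flg ->
    concat (map (\ys -> [if flg then x:ys else ys]) (filterList p xs)))
    (p x))

-- filterList (const [True, False]) [1, 2] == [[1,2],[1],[2],[]]
```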
The same thing then happens for the recursive call to `filterM`: for each value of `flg`, the recursive call is made and produces a list of results, and the final `return` statement is executed for each element `ys` in this list of results.
This "fan-out" on each recursive call leads to 2^3 = 8 elements in the final result of `filterM (\x -> [True, False]) [1, 2, 3]`.
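The fan-out can be counted directly. A sketch using the hypothetical name `powerset` (`filterM (const [True, False])` is a well-known idiom for computing the power set): each element doubles the number of results, so a list of n elements yields 2^n lists.

```haskell
import Control.Monad (filterM)

-- Hypothetical helper: keep-or-drop every element, in all combinations.
powerset :: [a] -> [[a]]
powerset = filterM (const [True, False])

-- powerset [1, 2, 3] == [[1,2,3],[1,2],[1,3],[1],[2,3],[2],[3],[]]
-- map (\n -> length (powerset [1 .. n])) [0 .. 3] == [1, 2, 4, 8]
```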
This is pretty straightforward once we put it all down on paper (as someone smart once advised: don't try to do it all in your head, put it all on paper!):
filterM :: (Monad m) => (a -> m Bool) -> [a] -> m [a]
filterM _ [] = return []
filterM p (x:xs) = do { flg <- p x
                      ; ys  <- filterM p xs
                      ; return (if flg then x:ys else ys) }
-- filterM (\x -> [True, False]) [1, 2, 3]
g [x,y,z] = filterM (\x -> [True, False]) (x:[y,z])
  = do {
         flg <- (\x -> [True, False]) x
       ; ys  <- g [y,z]
       ; return ([x | flg] ++ ys) }
  = do {
         flg <- [True, False]   -- no `x` here!
       ; ys  <- do { flg2 <- (\x -> [True, False]) y
                   ; zs   <- g [z]
                   ; return ([y | flg2] ++ zs) }
       ; return ([x | flg] ++ ys) }
  = do {
         flg  <- [True, False]
       ; flg2 <- [True, False]
       ; zs   <- do { flg3 <- (\x -> [True, False]) z
                    ; s    <- g []
                    ; return ([z | flg3] ++ s) }
       ; return ([x | flg] ++ [y | flg2] ++ zs) }
  = do {
         flg  <- [True, False]
       ; flg2 <- [True, False]
       ; flg3 <- [True, False]
       ; s    <- return []
       ; return ([x | flg] ++ [y | flg2] ++ [z | flg3] ++ s) }
The unnesting of the nested `do` blocks follows from the Monad laws (associativity in particular).
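To double-check the last step of the derivation, here it is as ordinary Haskell, with the hypothetical names `flattened` and `matchesFilterM`, compared against the library `filterM` on concrete arguments. `[x | flg]` is a list comprehension with a Boolean guard: it is `[x]` when `flg` is `True` and `[]` otherwise.

```haskell
import Control.Monad (filterM)

-- Hypothetical transcription of the final, fully unnested do block.
-- x, y, z stand for the three elements of the input list.
flattened :: a -> a -> a -> [[a]]
flattened x y z = do
  flg  <- [True, False]
  flg2 <- [True, False]
  flg3 <- [True, False]
  s    <- return []
  return ([x | flg] ++ [y | flg2] ++ [z | flg3] ++ s)

-- The derivation claims this equals filterM on [x, y, z].
matchesFilterM :: Eq a => a -> a -> a -> Bool
matchesFilterM x y z = flattened x y z == filterM (const [True, False]) [x, y, z]
```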
Thus, in pseudocode:
filterM (\x -> [True, False]) [1, 2, 3]
=
for flg in [True, False]:                -- x=1
    for flg2 in [True, False]:           -- y=2
        for flg3 in [True, False]:       -- z=3
            yield ([1 | flg] ++ [2 | flg2] ++ [3 | flg3])
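The pseudocode loops translate almost word for word into a Haskell list comprehension, sketched here with the hypothetical name `loops`. Generators are listed outermost first, so `flg3` varies fastest, which explains the order of the eight results.

```haskell
-- Hypothetical list-comprehension version of the nested loops above.
-- The first generator is the outermost loop; the last one varies fastest.
loops :: [[Int]]
loops =
  [ [1 | flg] ++ [2 | flg2] ++ [3 | flg3]
  | flg  <- [True, False]   -- x = 1
  , flg2 <- [True, False]   -- y = 2
  , flg3 <- [True, False]   -- z = 3
  ]

-- loops == [[1,2,3],[1,2],[1,3],[1],[2,3],[2],[3],[]]
```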