Understanding filterM

Posted 2020-08-09 09:25

Question:

Consider

filterM (\x -> [True, False]) [1, 2, 3]

I just cannot understand the magic that Haskell does with this filterM use case. The source code for this function is listed below:

filterM          :: (Monad m) => (a -> m Bool) -> [a] -> m [a]
filterM _ []     =  return []
filterM p (x:xs) =  do
    flg <- p x
    ys  <- filterM p xs
    return (if flg then x:ys else ys)

With this use case, p should be the lambda function (\x -> [True, False]), and the first x should be 1. So what does flg <- p x return? What exactly is the value of flg for each recursion?

Answer 1:

The list monad [] models non-determinism: a list of values [a] represents a number of different possibilities for the value of a.

When you see a statement like flg <- p x in the list monad, flg will take on each value of p x in turn, i.e. True and then False in this case. The rest of the body of filterM is then executed twice, once for each value of flg.
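As a minimal sketch of that behaviour, here is the one-element case written out by hand. The name oneStep is made up for illustration; it stands for what filterM does with the list [1] and the predicate from the question.

```haskell
-- Hypothetical helper: the filterM body for the single element 1,
-- with the recursive call on [] already replaced by its result [].
oneStep :: [[Int]]
oneStep = do
  flg <- [True, False]               -- flg is True on the first pass, False on the second
  return (if flg then [1] else [])   -- runs once per value of flg

main :: IO ()
main = print oneStep   -- prints [[1],[]]
```

The two results correspond to the two passes: keeping 1 when flg is True, dropping it when flg is False.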

To see how this happens in more detail, you need to understand the desugaring of do notation and the implementation of the (>>=) operator for the list monad.

do notation gets desugared line-by-line into calls to the (>>=) operator. For example the body of the non-empty filterM case turns into

p x >>= \flg -> (filterM p xs >>= \ys -> return (if flg then x:ys else ys))

This is completely mechanical: in essence, flg <- before the expression is replaced with >>= \flg -> after the expression. In reality pattern matching makes this a little more complicated, but not much.
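To check the desugaring yourself, you can write both forms side by side and compare them. This is only a sketch: filterMDo and filterMBind are made-up names, so they don't clash with the real filterM from Control.Monad.

```haskell
-- The do-notation version, copied from the question under a new name.
filterMDo :: Monad m => (a -> m Bool) -> [a] -> m [a]
filterMDo _ []     = return []
filterMDo p (x:xs) = do
  flg <- p x
  ys  <- filterMDo p xs
  return (if flg then x:ys else ys)

-- The same function with the do block desugared by hand into (>>=).
filterMBind :: Monad m => (a -> m Bool) -> [a] -> m [a]
filterMBind _ []     = return []
filterMBind p (x:xs) =
  p x >>= \flg ->
    filterMBind p xs >>= \ys ->
      return (if flg then x:ys else ys)

main :: IO ()
main = do
  let p  = const [True, False]
      xs = [1, 2, 3] :: [Int]
  print (filterMDo p xs == filterMBind p xs)   -- prints True
```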

Next is the actual implementation of (>>=), which is a member of the Monad type class and has a different implementation for each instance. For [], the type is:

(>>=) :: [a] -> (a -> [b]) -> [b]

and the implementation is something like

[] >>= f = []
(x:xs) >>= f = f x ++ (xs >>= f)

So the loop happens in the body of (>>=). This is all library code; there is no compiler magic beyond the desugaring of the do notation.

An equivalent definition for (>>=) is

 xs >>= f = concat (map f xs)

which may also help you see what's happening.
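If you want to experiment with the two definitions, a sketch like the following should work. bindList and bindList' are hypothetical stand-alone names, used so that the definitions don't collide with the real (>>=) from the Prelude.

```haskell
-- Hypothetical stand-in for the list instance of (>>=), recursive form.
bindList :: [a] -> (a -> [b]) -> [b]
bindList []     _ = []
bindList (x:xs) f = f x ++ bindList xs f

-- The equivalent concat/map form.
bindList' :: [a] -> (a -> [b]) -> [b]
bindList' xs f = concat (map f xs)

main :: IO ()
main = do
  let f n = [n, n * 10] :: [Int]
      xs  = [1, 2, 3]
  print (bindList xs f)                    -- prints [1,10,2,20,3,30]
  print (bindList xs f == bindList' xs f)  -- prints True
  print (bindList xs f == (xs >>= f))      -- prints True
```

Each element of xs is passed to f in turn and the resulting lists are concatenated, which is the "loop" described above.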

The same thing then happens for the recursive call to filterM: for each value of flg, the recursive call is made and produces a list of results, and the final return statement is executed once for each element ys of that list.

This "fan-out" on each recursive call leads to 2^3 = 8 elements in the final result of filterM (\x -> [True, False]) [1, 2, 3].
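You can confirm the count by running the expression from the question. The sketch below uses the library filterM from Control.Monad; result is just a name chosen here.

```haskell
import Control.Monad (filterM)

-- The expression from the question, with the unused argument written as _.
result :: [[Int]]
result = filterM (\_ -> [True, False]) [1, 2, 3]

main :: IO ()
main = do
  print result            -- prints [[1,2,3],[1,2],[1,3],[1],[2,3],[2],[3],[]]
  print (length result)   -- prints 8
```

The eight results are exactly the sublists of [1, 2, 3]; the ones that keep 1 come first because flg for the first element is True before it is False.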



Answer 2:

This is pretty straightforward, after we've put it all down on paper (someone smart once gave this advice: don't try doing it all in your head, put it all on paper!):

filterM          :: (Monad m) => (a -> m Bool) -> [a] -> m [a]
filterM _ []     =  return []
filterM p (x:xs) =  do { flg <- p x
                       ; ys  <- filterM p xs
                       ; return (if flg then x:ys else ys) }

-- filterM (\x -> [True, False]) [1, 2, 3]
g [x,y,z] = filterM (\x -> [True, False]) (x:[y,z])
          = do {
                 flg <- (\x -> [True, False]) x
               ; ys  <- g [y,z]
               ; return ([x | flg] ++ ys) }
         = do {
                flg <- [True, False]               -- no `x` here!
              ; ys  <- do { flg2 <- (\x -> [True, False]) y
                          ; zs  <- g [z]
                          ; return ([y | flg2] ++ zs) }
              ; return ([x | flg] ++ ys) }
         = do {
                flg  <- [True, False]
              ; flg2 <- [True, False]
              ; zs   <- do { flg3 <- (\x -> [True, False]) z
                           ; s  <- g []
                           ; return ([z | flg3] ++ s) }
              ; return ([x | flg] ++ [y | flg2] ++ zs) }
         = do {
                flg  <- [True, False]
              ; flg2 <- [True, False]
              ; flg3 <- [True, False]
              ; s    <- return []
              ; return ([x | flg] ++ [y | flg2] ++ [z | flg3] ++ s) }

The unnesting of the nested do blocks follows from the Monad laws.

Thus, in pseudocode:

    filterM (\x -> [True, False]) [1, 2, 3]
    =
      for flg in [True, False]:    -- x=1
          for flg2 in [True, False]:     -- y=2
              for flg3 in [True, False]:     -- z=3
                  yield ([1 | flg] ++ [2 | flg2] ++ [3 | flg3])
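
The nested loops above can also be written as an ordinary list comprehension, which is real Haskell rather than pseudocode. This is a sketch for comparison only; subsets is a made-up name.

```haskell
import Control.Monad (filterM)

-- The three nested "for" loops as one list comprehension.
-- [1 | flg] is itself a comprehension: [1] when flg is True, [] otherwise.
subsets :: [[Int]]
subsets =
  [ [1 | flg] ++ [2 | flg2] ++ [3 | flg3]
  | flg  <- [True, False]
  , flg2 <- [True, False]
  , flg3 <- [True, False]
  ]

main :: IO ()
main = print (subsets == filterM (\_ -> [True, False]) [1, 2, 3])   -- prints True
```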