How does this State monad code work?

Posted 2019-05-06 02:31

Question:

This code is from this article

I've been able to follow it until this part.

module Test where

type State = Int

data ST a = S (State -> (a, State))

apply        :: ST a -> State -> (a,State)
apply (S f) x = f x

fresh =  S (\n -> (n, n+1))

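-- Since GHC 7.10, every Monad must also be a Functor and an Applicative.
instance Functor ST where
    fmap g st  = S (\s -> let (x,s') = apply st s in (g x, s'))

instance Applicative ST where
    pure x     = S (\s -> (x,s))
    stf <*> st = S (\s -> let (g,s') = apply stf s; (x,s'') = apply st s' in (g x, s''))

instance Monad ST where
    -- return :: a -> ST a
    return x   = S (\s -> (x,s))

    -- (>>=)  :: ST a -> (a -> ST b) -> ST b
    st >>= f   = S (\s -> let (x,s') = apply st s in apply (f x) s')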

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show)


mlabel  :: Tree a -> ST (Tree (a,Int))
-- THIS IS THE PART I DON'T UNDERSTAND:
mlabel (Leaf x) = do n <- fresh
                     return (Leaf (x,n))
mlabel (Node l r) =  do l' <- mlabel l
                        r' <- mlabel r
                        return (Node l' r')

label t = fst (apply (mlabel t) 0)

tree = Node (Node (Leaf 'a') (Leaf 'b')) (Leaf 'c')

And label tree produces:

Node (Node (Leaf ('a',0)) (Leaf ('b',1))) (Leaf ('c',2))

I can see that the >>= operator is the tool for 'chaining' functions that return monadic values (or something like that).

And while I think I understand the general idea, I don't understand how this particular code works.

Specifically, do n <- fresh. We haven't passed any argument to fresh, right? What does n <- fresh produce in that case? I really don't understand that. Maybe it has something to do with currying?

Answer 1:

With the monadic "pipelining" inlined, your code becomes

fresh state = (state, state + 1)

mlabel (Leaf x) state =                   --  do
  let (n, state') = fresh state           --    n <- fresh
  in  (Leaf (x,n), state')                --    return (Leaf (x,n))

mlabel (Node l r) state =                 -- do
  let (l', state') = mlabel l state       --    l' <- mlabel l
  in let (r', state'') = mlabel r state'  --    r' <- mlabel r
     in  (Node l' r', state'')            --    return (Node l' r') 

main = let (result, state') = mlabel tree 0  
       in  print result                         

{- Or with arrows (needs import Control.Arrow and the TupleSections extension),

mlabel (Leaf x)   = Leaf . (x ,)  &&&  (+ 1)
mlabel (Node l r) = mlabel l >>> second (mlabel r)
                              >>> (\(a,(b,c)) -> (Node a b,c))
main              = mlabel tree >>> fst >>> print  $ 0
-}

Or in an imperative pseudocode:

def state = unassigned

def fresh ():
    tmp = state 
    state := state + 1     -- `fresh` alters the global var `state`
    return tmp             -- each time it is called

def mlabel (Leaf x):       -- a language with pattern matching
    n = fresh ()           -- global `state` is altered!
    return (Leaf (x,n))  

def mlabel (Node l r):
    l' = mlabel l          -- affects the global
    r' = mlabel r          --    assignable variable
    return (Node l' r')    --    `state`

def main:
    state := 0             -- set to 0 before the calculation!
    result = mlabel tree
    print result

Calculating the result changes the assignable variable state; it corresponds to the snd field of Haskell's (a, State) tuple. The fst field of the tuple is the newly constructed tree, which carries a number alongside the data in each of its leaves.

These variants are functionally equivalent.

Perhaps you've heard the catch-phrase about monadic bind being a "programmable semicolon". Here its meaning is clear: bind defines the "function call protocol", so to speak. We use the first returned value as the calculated result and the second returned value as the updated state, which we pass along to the next calculation so that it gets to see the updated state.
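The "protocol" described above can be written down as an ordinary function on plain state-passing functions, with no type class involved. What follows is only a minimal sketch: the names StateFn, andThen, unit and twoFresh are made up for this illustration and do not appear in the question's code.

```haskell
-- A state-passing computation: takes the current state,
-- returns a result together with the updated state.
type StateFn s a = s -> (a, s)

-- Hypothetical name: the "programmable semicolon" for such functions.
-- Run `first`, hand its result to `next`, and feed the updated
-- state into the computation that `next` returns.
andThen :: StateFn s a -> (a -> StateFn s b) -> StateFn s b
andThen first next = \s ->
  let (x, s') = first s
  in  next x s'

-- Hypothetical name: produce a result and leave the state untouched.
unit :: a -> StateFn s a
unit x = \s -> (x, s)

-- Same idea as `fresh` in the question, minus the `S` wrapper.
fresh :: StateFn Int Int
fresh n = (n, n + 1)

-- Two calls to `fresh` in sequence; `andThen` threads the state.
twoFresh :: StateFn Int (Int, Int)
twoFresh =
  fresh `andThen` \a ->
  fresh `andThen` \b ->
  unit (a, b)
```

Evaluating twoFresh 0 in GHCi should give ((0,1),2): the second call to fresh sees the state left behind by the first, even though no state variable is mentioned in the body of twoFresh.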

This is the state-passing style of programming (essential in, e.g., Prolog): it makes the state change explicit, but we have to take care of passing along the correct, updated state by hand. Monads let us abstract this "wiring" of state from one calculation to the next, so that it is done for us automatically. The price is that we have to think in an imperative style, and that the state becomes hidden and implicit again (just as state change is implicit in imperative programming, which is what we wanted to get away from when switching to functional programming in the first place...).

So all the State monad does is maintain this hidden state for us and pass it along, updated, between consecutive calculations. It's nothing major after all.



Answer 2:

Specifically do n <- fresh. We haven't passed any argument to fresh, right?

Exactly. We are writing code that waits for an argument, which is passed to fresh only later, when we, for instance, do something like apply (mlabel someTree) 5. A nice exercise that will help you see more clearly what is going on is to first write mlabel with explicit (>>=) instead of do-notation, and then replace (>>=) and return with what the Monad instance says they are.
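For the Leaf case, that exercise might go roughly as follows. To keep each step visible without a type class, this sketch uses two standalone helpers, bindST and returnST, which just copy the bodies of (>>=) and return from the question's Monad instance. Those two names, and leafV1 to leafV3, are hypothetical and exist only for this illustration.

```haskell
type State = Int

data ST a = S (State -> (a, State))

apply :: ST a -> State -> (a, State)
apply (S f) x = f x

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show, Eq)

fresh :: ST Int
fresh = S (\n -> (n, n + 1))

-- Copies of the instance methods under hypothetical names.
returnST :: a -> ST a
returnST x = S (\s -> (x, s))

bindST :: ST a -> (a -> ST b) -> ST b
bindST st f = S (\s -> let (x, s') = apply st s in apply (f x) s')

-- Step 1: the do block written with explicit bind.
leafV1 :: a -> ST (Tree (a, Int))
leafV1 x = fresh `bindST` \n -> returnST (Leaf (x, n))

-- Step 2: the definition of bind substituted in.
leafV2 :: a -> ST (Tree (a, Int))
leafV2 x = S (\s -> let (n, s') = apply fresh s
                    in  apply (returnST (Leaf (x, n))) s')

-- Step 3: `fresh` and `returnST` substituted in and simplified.
-- The lambda-bound `s` is the argument that `fresh` was waiting for.
leafV3 :: a -> ST (Tree (a, Int))
leafV3 x = S (\s -> (Leaf (x, s), s + 1))
```

All three versions should behave identically; for example, apply (leafV3 'a') 5 gives (Leaf ('a',5),6). Nothing is passed to fresh while mlabel is being defined: the state only arrives when apply finally supplies it.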



Answer 3:

The key thing to realise is that do notation gets translated into Monad functions, so

do n <- fresh
   return (Leaf (x,n))

is short for

fresh >>= (\n -> 
           return (Leaf (x,n))  )

and

do l' <- mlabel l
   r' <- mlabel r
   return (Node l' r')

is short for

mlabel l >>= (\l' -> 
              mlabel r >>= (\r' ->
                            return (Node l' r') ))

This will hopefully allow you to continue figuring out the code's meaning, but for more help, you should read up on the do notation for Monads.
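One way to convince yourself that the two spellings mean the same thing is to put them side by side and run both on the question's tree. The sketch below assumes a recent GHC, so it adds the Functor and Applicative instances that the compiler now requires. The names mlabelDo and mlabelBind are made up here to tell the two versions apart.

```haskell
type State = Int

data ST a = S (State -> (a, State))

apply :: ST a -> State -> (a, State)
apply (S f) x = f x

-- Modern GHC requires Functor and Applicative instances alongside Monad.
instance Functor ST where
  fmap g st = S (\s -> let (x, s') = apply st s in (g x, s'))

instance Applicative ST where
  pure x = S (\s -> (x, s))
  stf <*> st = S (\s ->
    let (g, s')  = apply stf s
        (x, s'') = apply st s'
    in  (g x, s''))

instance Monad ST where
  st >>= f = S (\s -> let (x, s') = apply st s in apply (f x) s')

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show, Eq)

fresh :: ST Int
fresh = S (\n -> (n, n + 1))

-- do-notation version, as in the question.
mlabelDo :: Tree a -> ST (Tree (a, Int))
mlabelDo (Leaf x) = do
  n <- fresh
  return (Leaf (x, n))
mlabelDo (Node l r) = do
  l' <- mlabelDo l
  r' <- mlabelDo r
  return (Node l' r')

-- Hand-desugared version, using only (>>=) and lambdas.
mlabelBind :: Tree a -> ST (Tree (a, Int))
mlabelBind (Leaf x) =
  fresh >>= \n ->
  return (Leaf (x, n))
mlabelBind (Node l r) =
  mlabelBind l >>= \l' ->
  mlabelBind r >>= \r' ->
  return (Node l' r')

tree :: Tree Char
tree = Node (Node (Leaf 'a') (Leaf 'b')) (Leaf 'c')
```

In GHCi, fst (apply (mlabelDo tree) 0) and fst (apply (mlabelBind tree) 0) should both print the labelled tree shown in the question, and the final state in both cases is 3, one past the last label handed out.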